DPPL research code #572

Open · wants to merge 5 commits into base: master
87 changes: 87 additions & 0 deletions research/dppl_2024/README.md
@@ -0,0 +1,87 @@
# Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning
This folder contains the code for

**Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning**
by Dariush Wahdany, Matthew Jagielski, Adam Dziedzic, Franziska Boenisch
https://arxiv.org/abs/2406.08039

Abstract:
Machine learning (ML) models have been shown to leak private information from their training datasets. Differential Privacy (DP), typically implemented through the differentially private stochastic gradient descent algorithm (DP-SGD), has become the standard solution to bound leakage from the models. Despite recent improvements, DP-SGD-based approaches for private learning still usually struggle in the high privacy ($\varepsilon<0.1$) and low data regimes, and when the private training datasets are imbalanced. To overcome these limitations, we propose Differentially Private Prototype Learning (DPPL) as a new paradigm for private transfer learning. DPPL leverages publicly pre-trained encoders to extract features from private data and generates DP prototypes that represent each private class in the embedding space and can be publicly released for inference. Since our DP prototypes can be obtained from only a few private training data points and without iterative noise addition, they offer high-utility predictions and strong privacy guarantees even under the notion of pure DP. We additionally show that privacy-utility trade-offs can be further improved when leveraging the public data beyond pre-training of the encoder: we are able to privately sample our DP prototypes from the publicly available data points used to train the encoder. Our experimental evaluation with four state-of-the-art encoders, four vision datasets, and under different data and imbalance regimes demonstrates DPPL's high performance under strong privacy guarantees in challenging private learning setups.



## Table of Contents

- [Installation](#installation)
- [Description](#description)
- [Usage](#usage)

## Installation

### Conda
```bash
conda env create -f env.yaml
```

### Pip
```bash
pip install -r requirements.txt
```

## Description

### Imbalanced Datasets
We construct the imbalanced datasets in `lib.utils.give_imbalanced_set`. The function upper-bounds the number of samples per class based on the minimum number of samples across classes, so for an imbalance ratio of $1$ the dataset is actually balanced. `lib.utils.decay` implements the decay function $f(c)=N\exp(-\lambda c)$. The class indices are shuffled depending on the seed, so it is random which classes end up as majority and which as minority classes. A minimal sketch of this construction follows below.
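For illustration, here is a minimal sketch of that construction. The names `decay` and `give_imbalanced_set` mirror `lib.utils`, but the exact signatures and details are assumptions, not the repository's code:

```python
import numpy as np


def decay(c, n, lam):
    # f(c) = N * exp(-lambda * c): number of samples kept for the c-th class.
    return max(int(round(n * np.exp(-lam * c))), 1)


def give_imbalanced_set(x, y, imbalance_ratio, seed=42):
    rng = np.random.default_rng(seed)
    # Shuffle class labels so majority/minority membership is random.
    classes = rng.permutation(np.unique(y))
    n = min((y == c).sum() for c in classes)  # upper bound: min class size
    # Choose lambda so the smallest class keeps n / imbalance_ratio samples.
    lam = np.log(imbalance_ratio) / max(len(classes) - 1, 1)
    idx = np.concatenate(
        [np.flatnonzero(y == c)[: decay(rank, n, lam)] for rank, c in enumerate(classes)]
    )
    return x[idx], y[idx]
```

With `imbalance_ratio=1`, `lam` is $0$ and every class keeps exactly the minimum class size, which matches the balanced case described above.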

### DPPL-Mean
The implementation of **DPPL-Mean** can be found in `dppl_mean.py`. We first load the private dataset, average-pool its features, and obtain imbalanced datasets as described above.
The private mean estimation uses our JAX reimplementation of [*CoinPress*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/a684eceee76fc522773286a895bc8436-Abstract.html) in `lib.coinpress`.
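Schematically, each CoinPress iteration clips the points to a ball around the current center estimate and releases a Gaussian-noised mean under part of the zCDP budget (the `ps` array in `dppl_mean.py`). Below is a simplified single-step sketch under those assumptions, not the actual `lib.coinpress` code; the real algorithm also shrinks the radius between steps and enlarges the clipping ball to account for the noise:

```python
import jax
import jax.numpy as jnp


def coinpress_step(x, c, r, rho, key):
    """One simplified CoinPress refinement step with zCDP budget rho."""
    # Project every point onto the ball B(c, r).
    dist = jnp.linalg.norm(x - c, axis=1, keepdims=True)
    x_clip = jnp.where(dist > r, c + (x - c) * (r / dist), x)
    # The clipped mean has L2 sensitivity 2r/n; the Gaussian mechanism
    # satisfies rho-zCDP with sigma = sensitivity / sqrt(2 * rho).
    sensitivity = 2 * r / x.shape[0]
    sigma = sensitivity / jnp.sqrt(2 * rho)
    noise = sigma * jax.random.normal(key, (x.shape[1],))
    return x_clip.mean(axis=0) + noise
```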

### DPPL-Public
The implementation of **DPPL-Public** can be found in `dppl_public.py`. We first load the private dataset and obtain imbalanced datasets as described above. The scores are computed using `lib.utils.pairwise_distance`, which returns cosine distances $\in [0,2]$. `lib.utils.scores_single` implements the score calculation for a single public sample: it subtracts the distance to each private sample from $2$, clips the result to $[d_{\min},d_{\max}]$, and normalizes it to $[0,1]$ before summing over all private samples. In our implementation the sensitivity is therefore always $1$, but the mechanism is identical to one where the scores are not normalized to $[0,1]$ and the sensitivity is reduced instead.
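A minimal sketch of this score, assuming row-vector embeddings (the actual `lib.utils.scores_single` operates on precomputed distances and is vectorized differently):

```python
import jax.numpy as jnp


def score_single(public_emb, private_embs, min_score, max_score):
    # Cosine distance lies in [0, 2]; subtracting it from 2 gives a
    # similarity 1 + cos(angle), also in [0, 2].
    cos = private_embs @ public_emb / (
        jnp.linalg.norm(private_embs, axis=1) * jnp.linalg.norm(public_emb)
    )
    sim = 1.0 + cos  # = 2 - cosine distance
    # Clip to [min_score, max_score] and rescale to [0, 1]: each private
    # sample then contributes at most 1, so the summed score has sensitivity 1.
    clipped = jnp.clip(sim, min_score, max_score)
    return ((clipped - min_score) / (max_score - min_score)).sum()
```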

Finally, given the scores, `lib.public.exponential` implements the exponential mechanism. If the utility function is not monotonic, we multiply the sensitivity by $2$ to achieve $\epsilon$-DP. For numerical stability, we subtract the maximum exponent from all exponents before exponentiating. This scales every unnormalized probability by the same constant factor $\exp(-c)$, so the proportions, and therefore the mechanism, are unchanged: the exponential mechanism is invariant to constant shifts of the utility function.
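A sketch of such an exponential mechanism (the repository's `lib.public.exponential` accepts an integer key and may differ in details):

```python
import jax
import jax.numpy as jnp


def exponential(scores, sensitivity, epsilon, key, size=1, monotonic=True):
    # Non-monotonic utility functions need the extra factor 2.
    scale = epsilon / (sensitivity if monotonic else 2.0 * sensitivity)
    logits = scale * scores
    logits = logits - logits.max()  # subtract the max exponent for stability
    probs = jnp.exp(logits) / jnp.exp(logits).sum()
    return jax.random.choice(key, scores.shape[0], shape=(size,), p=probs)


# Example: select one index from ten utility scores under epsilon = 1.
idx = exponential(jnp.arange(10.0), 1.0, 1.0, jax.random.key(0))
```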

### DPPL-Public Top-K
The implementation of **DPPL-Public Top-K** can be found in `dppl_public_topk.py`. We first load the private dataset and obtain imbalanced datasets as described above. The scores are computed as in [DPPL-Public](#dppl-public). Our unordered top-$K$ selection is implemented using the efficient sampling algorithm from [Duff](http://arxiv.org/abs/2010.04235) (Prop. 5). `lib.public.give_topk_proto_idx` returns the indices of the prototypes with respect to the order of $C$: a returned $0$ denotes the candidate with the best utility, $1$ the second best, and so on. To do so, a utility is sampled with `lib.public.exponential_parallel`, which applies the exponential mechanism in parallel for all classes. The remainder of `lib.public.give_topk_proto_idx` uniformly samples the remaining $K-1$ prototypes such that their utility is higher than the sampled one; see the sketch below.
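The following condenses the idea for a single class. It is a much-simplified sketch: the real `lib.public.give_topk_proto_idx` vectorizes this over classes and uses the `log_binom` weights computed in `dppl_public_topk.py`, so treat the names and details here as illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_binom(n, k):
    # log of the binomial coefficient C(n, k).
    return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)


def sample_topk_indices(u, k, epsilon, key):
    # u: utilities sorted in descending order (u[0] >= u[1] >= ...).
    # Sample the index of the *worst* member of the top-k set: candidate i is
    # the worst member of C(i, k-1) possible sets, hence the log-binomial bonus.
    idx = jnp.arange(u.shape[0])
    logits = jnp.where(
        idx >= k - 1, epsilon * u / 2.0 + log_binom(idx, k - 1), -jnp.inf
    )
    key_last, key_rest = jax.random.split(key)
    last = int(jax.random.categorical(key_last, logits))
    # The remaining k-1 prototypes are drawn uniformly from the candidates
    # with strictly higher utility, i.e. those preceding `last`.
    rest = jax.random.choice(key_rest, last, shape=(k - 1,), replace=False)
    return jnp.sort(jnp.append(rest, last))
```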

### Hyperparameters
We provide the hyperparameters for the models and datasets we used in `hparams_mean.md`, `hparams_public.md` and `hparams_public_topk.md`.

## Usage

Before running any of the experiments, set the path to your embeddings in `conf/common.yaml`. Further options are:
- Epsilon (`epsilon`)
- Imbalance ratio (`imbalance_ratio`)
- Seed (`seed`)

We provide the required embeddings as a [huggingface dataset](https://huggingface.co/datasets/lsc64/DPPL-embeddings).

### DPPL-Mean
(Optional): In `conf/mean.yaml`, set `pool` to the desired integer value. It configures the optional average pooling applied before the mean estimation and can improve utility, especially at strict privacy budgets.

```bash
python dppl_mean.py
```
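Since the experiments are configured with [Hydra](https://hydra.cc), any of the options above can also be overridden on the command line instead of editing the YAML files, e.g.:

```bash
python dppl_mean.py pool=4 epsilon=0.1 imbalance_ratio=10 seed=1
```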
### DPPL-Public
(Optional): In `conf/public.yaml`, set `max_score` and `min_score` to any desired values in $[0,2]$ with `min_score < max_score`. They define the clipping range of the scores and can improve utility, especially at strict privacy budgets.

**Required**: In `conf/public.yaml`, set `dataset.public_data` to the path to your public dataset embeddings.


```bash
python dppl_public.py
```

### DPPL-Public Top-K
(Optional): In `conf/public_topk.yaml`, set `max_score` and `min_score` to any desired values in $[0,2]$ with `min_score < max_score`. They define the clipping range of the scores and can improve utility, especially at strict privacy budgets. Also, set `k` to the desired integer value; it defines how many prototypes are selected per class and can improve utility, especially in lower-privacy regimes.

**Required**: In `conf/public_topk.yaml`, set `dataset.public_data` to the path to your public dataset embeddings.


```bash
python dppl_public_topk.py
```

9 changes: 9 additions & 0 deletions research/dppl_2024/conf/common.yaml
@@ -0,0 +1,9 @@
seed: 42
dataset:
  train_data: "embeddings/vit_h_14_cifar100_train.npy"
  train_labels: "embeddings/vit_h_14_cifar100_train_targets.npy"
  test_data: "embeddings/vit_h_14_cifar100_test.npy"
  test_labels: "embeddings/vit_h_14_cifar100_test_targets.npy"

imbalance_ratio: 1
epsilon: 0.5
4 changes: 4 additions & 0 deletions research/dppl_2024/conf/mean.yaml
@@ -0,0 +1,4 @@
defaults:
  - common
  - _self_
pool: 1
9 changes: 9 additions & 0 deletions research/dppl_2024/conf/public.yaml
@@ -0,0 +1,9 @@
defaults:
  - common
  - _self_

dataset:
  public_data: "embeddings/vit_h_14_imagenet64.npy"

max_score: 1.65
min_score: 1.35
10 changes: 10 additions & 0 deletions research/dppl_2024/conf/public_topk.yaml
@@ -0,0 +1,10 @@
defaults:
  - common
  - _self_

dataset:
  public_data: "embeddings/vit_h_14_imagenet64.npy"

k: 5
max_score: 1.65
min_score: 1.35
56 changes: 56 additions & 0 deletions research/dppl_2024/dppl_mean.py
@@ -0,0 +1,56 @@
import flax.linen.pooling as pooling
import hydra
import jax
import jax.numpy as jnp
from omegaconf import DictConfig, OmegaConf

from lib import coinpress, utils


@hydra.main(config_path="conf", config_name="mean", version_base=None)
def main(cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

    x_train, y_train, x_test, y_test = utils.load_dataset(cfg)
    # Optional average pooling of the embeddings (cfg.pool == 1 is a no-op).
    x_train = pooling.avg_pool(
        x_train.T, window_shape=(cfg.pool,), strides=(cfg.pool,)
    ).T
    x_test = pooling.avg_pool(
        x_test.T, window_shape=(cfg.pool,), strides=(cfg.pool,)
    ).T
    x_imbalanced, y_imbalanced = utils.give_imbalanced_set(
        x_train, y_train, cfg.imbalance_ratio
    )
    classes = jnp.unique(y_imbalanced)
    if cfg.epsilon < jnp.inf:
        # Convert epsilon to zCDP and split the budget across the CoinPress steps.
        rho = utils.zcdp_of_naive_epsilon(cfg.epsilon)
        ps = jnp.array([5 / 64, 7 / 64, 52 / 64]) * rho
        key = jax.random.key(cfg.seed)
        class_keys = jax.random.split(key, len(classes))
        r = jnp.sqrt(x_imbalanced.shape[1])
        protos = jnp.stack(
            [
                coinpress.private_mean_jit(
                    x_imbalanced[y_imbalanced == i], ps, key=class_keys[i], r=r
                )
                for i in classes
            ]
        )
    else:
        # Non-private baseline: plain per-class means.
        protos = jnp.stack(
            [x_imbalanced[y_imbalanced == i].mean(axis=0) for i in classes]
        )
    # Nearest-prototype classification on the test set.
    dists_test = utils.pairwise_distance(protos, x_test)
    test_acc = float((dists_test.argmin(axis=0) == y_test).mean())
    test_acc_per_class = jnp.stack(
        [
            (dists_test[..., y_test == target].argmin(axis=0) == target).mean()
            for target in classes
        ]
    )
    print(f"Test accuracy: {test_acc}")
    print(f"Test accuracy per class: {test_acc_per_class}")


if __name__ == "__main__":
    main()
73 changes: 73 additions & 0 deletions research/dppl_2024/dppl_public.py
@@ -0,0 +1,73 @@
import warnings

import hydra
import jax
import jax.numpy as jnp
import numpy as np
from omegaconf import DictConfig, OmegaConf

from lib import public, utils


@hydra.main(config_path="conf", config_name="public", version_base=None)
def main(cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

    # Convert the naive epsilon to zCDP and back to an exponential-mechanism epsilon.
    rho = utils.zcdp_of_naive_epsilon(cfg.epsilon)
    actual_epsilon = utils.exponential_epsilon_of_zcdp(rho)
    print(
        f"Converted settings epsilon {cfg.epsilon} to rho {rho} "
        f"to exponential epsilon {actual_epsilon}"
    )

    x_train, y_train, x_test, y_test = utils.load_dataset(cfg)
    x_public = utils.load_public_dataset(cfg)
    x_imbalanced, y_imbalanced = utils.give_imbalanced_set(
        x_train, y_train, cfg.imbalance_ratio
    )
    classes = jnp.unique(y_imbalanced)
    try:
        jax.devices("gpu")
    except RuntimeError:
        warnings.warn("No GPU found, falling back to CPU. This will be slow.")
    # Clipped, normalized utility of every public sample for every class.
    scores = jnp.stack(
        [
            utils.scores_multiple(
                x_imbalanced[y_imbalanced == target],
                x_public,
                cfg.min_score,
                cfg.max_score,
            )
            for target in classes
        ]
    )
    sensitivity = 1.0
    proto_idx_per_class = []
    for target in classes:
        # Select one public prototype per class with the exponential mechanism.
        proto_idx_per_class.append(
            public.exponential(
                scores=scores[target],
                sensitivity=sensitivity,
                epsilon=actual_epsilon,
                size=1,
                monotonic=True,
                key=int(cfg.seed + target),
            )
        )
    public_protos = x_public[np.concatenate(proto_idx_per_class)].reshape(
        len(classes), x_public.shape[-1]
    )
    # Nearest-prototype classification on the test set.
    dists_test = utils.pairwise_distance(public_protos, x_test)
    test_acc = float((dists_test.argmin(axis=0) == y_test).mean())
    test_acc_per_class = jnp.stack(
        [
            (dists_test[..., y_test == target].argmin(axis=0) == target).mean()
            for target in classes
        ]
    )
    print(f"Test accuracy: {test_acc}")
    print(f"Test accuracy per class: {test_acc_per_class}")


if __name__ == "__main__":
    main()
85 changes: 85 additions & 0 deletions research/dppl_2024/dppl_public_topk.py
@@ -0,0 +1,85 @@
import warnings
from functools import partial

import hydra
import jax
import jax.numpy as jnp
from omegaconf import DictConfig, OmegaConf

from lib import public, utils


@hydra.main(config_path="conf", config_name="public_topk", version_base=None)
def main(cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

    # Convert the naive epsilon to zCDP and back to an exponential-mechanism epsilon.
    rho = utils.zcdp_of_naive_epsilon(cfg.epsilon)
    actual_epsilon = utils.exponential_epsilon_of_zcdp(rho)
    print(
        f"Converted settings epsilon {cfg.epsilon} to rho {rho} "
        f"to exponential epsilon {actual_epsilon}"
    )

    x_train, y_train, x_test, y_test = utils.load_dataset(cfg)
    x_public = utils.load_public_dataset(cfg)
    x_imbalanced, y_imbalanced = utils.give_imbalanced_set(
        x_train, y_train, cfg.imbalance_ratio
    )
    classes = jnp.unique(y_imbalanced)
    try:
        jax.devices("gpu")
    except RuntimeError:
        warnings.warn("No GPU found, falling back to CPU. This will be slow.")
    # Clipped, normalized utility of every public sample for every class.
    scores = jnp.stack(
        [
            utils.scores_multiple(
                x_imbalanced[y_imbalanced == target],
                x_public,
                cfg.min_score,
                cfg.max_score,
            )
            for target in classes
        ]
    )
    # Sort candidates per class by descending score.
    c_idx = jnp.argsort(scores, axis=1, descending=True)
    if cfg.epsilon < jnp.inf:
        c = jnp.stack([scores[i, c_idx[i]] for i in range(scores.shape[0])])
        # Utility of a candidate relative to the k-th best score of its class.
        u = c - c[:, cfg.k - 1][:, jnp.newaxis]
        with jax.experimental.enable_x64():
            logm = jax.vmap(partial(public.log_binom, k=cfg.k), in_axes=(0))(
                jnp.arange(scores.shape[-1])
            )
            proto_idx_c = public.give_topk_proto_idx(
                u,
                logm,
                cfg.k,
                u.shape[0],
                u.shape[1],
                actual_epsilon,
                cfg.seed,
            )
        proto_idx = jnp.stack(
            [
                c_idx[jnp.arange(c_idx.shape[0]), proto_idx_c[:, k_i]]
                for k_i in range(cfg.k)
            ]
        ).T
    else:
        # Non-private baseline: take the k best-scoring public samples per class.
        proto_idx = jnp.stack(
            [c_idx[jnp.arange(c_idx.shape[0]), k_i] for k_i in range(cfg.k)]
        ).T
    public_protos = x_public[proto_idx.flatten()].reshape((*proto_idx.shape, -1))
    # Nearest-prototype classification on the test set.
    dists_test = utils.pairwise_distance(public_protos, x_test)
    test_acc = float((dists_test.argmin(axis=0) == y_test).mean())
    test_acc_per_class = jnp.stack(
        [
            (dists_test[..., y_test == target].argmin(axis=0) == target).mean()
            for target in classes
        ]
    )
    print(f"Test accuracy: {test_acc}")
    print(f"Test accuracy per class: {test_acc_per_class}")


if __name__ == "__main__":
    main()