diff --git a/research/dppl_2024/README.md b/research/dppl_2024/README.md new file mode 100644 index 00000000..0dbbdcf6 --- /dev/null +++ b/research/dppl_2024/README.md @@ -0,0 +1,87 @@ +# Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning +This folder contains the code for + +**Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning** +by Dariush Wahdany, Matthew Jagielski, Adam Dziedzic, Franziska Boenisch +https://arxiv.org/abs/2406.08039 + +Abstract: +Machine learning (ML) models have been shown to leak private information from their training datasets. Differential Privacy (DP), typically implemented through the differential private stochastic gradient descent algorithm (DP-SGD), has become the standard solution to bound leakage from the models. Despite recent improvments, DP-SGD-based approaches for private learning still usually struggle in the high privacy ($\varepsilon<0.1$) and low data regimes, and when the private training datasets are imbalanced. To overcome these limitations, we propose Differentially Private Prototype Learning (DPPL) as a new paradigm for private transfer learning. DPPL leverages publicly pre-trained encoders to extract features from private data and generates DP prototypes that represent each private class in the embedding space and can be publicly released for inference. Since our DP prototypes can be obtained from only a few private training data points and without iterative noise addition, they offer high-utility predictions and strong privacy guarantees even under the notion of pure DP. We additionally show that privacy-utility trade-offs can be further improved when leveraging the public data beyond pre-training of the encoder: we are able to privately sample our DP prototypes from the publicly available data points used to train the encoder. Our experimental evaluation with four state-of-the-art encoders, four vision datasets, and under different data and unbalancedness regimes demonstrate DPPL's high performance under strong privacy guarantees in challenging private learning setups. + + + +## Table of Contents + +- [Installation](#installation) +- [Description](#description) +- [Usage](#usage) +- [Contributing](#contributing) +- [License](#license) + +## Installation + +### Conda +```bash +conda env create -f env.yaml +``` + +### Pip +```bash +pip install -r requirements.txt +``` + +## Description + +### Imbalanced Datasets +We construct the imbalanced datasets in `lib.utils.give_imbalanced_set`. The function places an upper bound on the number of samples per class according to the minimum number of samples per class. So for an imbalance ratio of $1$, the dataset is actually balanced. `lib.utils.decay` implements the decay function $f(c)=N\exp{-\lambda c}$. The class indices are shuffled depending on the seed, therefore whether classes are part of the majority or minority classes is random. + +### DPPL-Mean +The implementation of **DPPL-Mean** can be found in `dppl_mean.py`. We first load the private dataset, average-pool its features and obtain imbalanced datasets as described above. +The private mean estimation occurs using the Jax-reimplementation of [*CoinPress*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/a684eceee76fc522773286a895bc8436-Abstract.html) in `lib.utils.coinpress`. + +### DPPL-Public +The implementation of **DPPL-Public** can be found in `dppl_public.py`. We first load the private dataset and obtain imbalanced datasets as described above. The scores are computed using `lib.utils.pairwise_distance`, a function returning cosine distances $\in [0,2]$. `lib.utils.scores_single` implements the score calculation for a single public sample, by substracting the distance to each private sample from $2$, clipping the result to $[d_{\min},d_{\max}]$ and normalizing it to $[0,1]$, before summing over all the private samples. In our implementation the sensitivity is therefore always $1$, but the mechanism is identical to one where the scores are not normalized to $[0,1]$ and the sensitivity is reduced instead. + +Finally, given the scores `lib.public.exponential` implements the exponential mechanism. Depending on whether the utility function is monotonic or not, we multiply the sensitivity by $2$ to achieve $\epsilon$-DP. For numerical reasons, the substract from all exponents the maximum exponent. Since this is the constant factor $\exp(-c)$ for all samples, the proportionality of the probalities and therefore the mechanism doesn't change, since the exponential mechanism is invariant to scaling of the utility function. + +### DPPL-Public Top-K +The implementation of **DPPL-Public Top-K** can be found in `dppl_public_topk.py`. We first load the private dataset and obtain imbalanced datasets as described above. The scores are computed as in [DPPL-Public](#dppl-public). Our unordered top-K selection is implemented using the efficient sampling algorithm from [Duff](http://arxiv.org/abs/2010.04235) (Prop. 5). `lib.public.give_topk_proto_idx` returns the indices of the prototypes w.r.t. to the order of C, i.e. if it returns $0$ it means the best utility, $1$ the second best and so on. To do so, the utility is sampled with `lib.public.exponential_parallel` using the exponential mechanism in parallel for all classes. The remainder of `lib.public.give_topk_proto_idx` is just to uniformly sample the remaining $K-1$ prototypes, s.t. their utility is higher than the sampled one. + +### Hyperparameters +We provide the hyperparameters for the models and datasets we used in `hparams_mean.md`, `hparams_public.md` and `hparams_public_topk.md`. + +## Usage + +Before running any of the experiments, set the path to your embeddings in `config/common.yaml`. Further options are +- Epsilon +- Imbalance Ratio +- Seed + +We provide the required embeddings as a [huggingface dataset](https://huggingface.co/datasets/lsc64/DPPL-embeddings). + +### DPPL-Mean +(Optional): In `config/mean.yaml`, change `pool` to any desired integer value. It configures the optional average pooling before the mean estimation and can improve utility especially at strict privacy budgets. + +```bash +python dppl_mean.py +``` +### DPPL-Public +(Optional): In `config/public.yaml`, change `max_score` and `min_score` to any desired values in [0,2], s.t. min_score < max_score. It defines the clipping of the scores and can improve utility especially at strict privacy budgets. + +**Required**: In `config/public.yaml`, change `dataset.public_data` to the path to your public dataset embeddings. + + +```bash +python dppl_mean.py +``` + +### DPPL-Public Top-K +(Optional): In `config/public_topk.yaml`, change `max_score` and `min_score` to any desired values in [0,2], s.t. min_score < max_score. It defines the clipping of the scores and can improve utility especially at strict privacy budgets. Also, change `k` to any integer value. It defines how many prototypes are selected per class and can improve utility especially at lower privacy regimes. + +**Required**: In `config/public_topk.yaml`, change `dataset.public_data` to the path to your public dataset embeddings. + + +```bash +python dppl_public_topk.py +``` + diff --git a/research/dppl_2024/conf/common.yaml b/research/dppl_2024/conf/common.yaml new file mode 100644 index 00000000..afe9e980 --- /dev/null +++ b/research/dppl_2024/conf/common.yaml @@ -0,0 +1,9 @@ +seed: 42 +dataset: + train_data: "embeddings/vit_h_14_cifar100_train.npy" + train_labels: "embeddings/vit_h_14_cifar100_train_targets.npy" + test_data: "embeddings/vit_h_14_cifar100_test.npy" + test_labels: "embeddings/vit_h_14_cifar100_test_targets.npy" + +imbalance_ratio: 1 +epsilon: 0.5 diff --git a/research/dppl_2024/conf/mean.yaml b/research/dppl_2024/conf/mean.yaml new file mode 100644 index 00000000..7b342957 --- /dev/null +++ b/research/dppl_2024/conf/mean.yaml @@ -0,0 +1,4 @@ +defaults: + - common + - _self_ +pool: 1 diff --git a/research/dppl_2024/conf/public.yaml b/research/dppl_2024/conf/public.yaml new file mode 100644 index 00000000..dc4cf203 --- /dev/null +++ b/research/dppl_2024/conf/public.yaml @@ -0,0 +1,9 @@ +defaults: + - common + - _self_ + +dataset: + public_data: "embeddings/vit_h_14_imagenet64.npy" + +max_score: 1.65 +min_score: 1.35 diff --git a/research/dppl_2024/conf/public_topk.yaml b/research/dppl_2024/conf/public_topk.yaml new file mode 100644 index 00000000..0b6abde7 --- /dev/null +++ b/research/dppl_2024/conf/public_topk.yaml @@ -0,0 +1,10 @@ +defaults: + - common + - _self_ + +dataset: + public_data: "embeddings/vit_h_14_imagenet64.npy" + +k: 5 +max_score: 1.65 +min_score: 1.35 diff --git a/research/dppl_2024/dppl_mean.py b/research/dppl_2024/dppl_mean.py new file mode 100644 index 00000000..9d8baca5 --- /dev/null +++ b/research/dppl_2024/dppl_mean.py @@ -0,0 +1,56 @@ +import flax.linen.pooling as pooling +import hydra +import jax +import jax.numpy as jnp +from omegaconf import DictConfig, OmegaConf + +from lib import coinpress, utils + + +@hydra.main(config_path="conf", config_name="mean", version_base=None) +def main(cfg: DictConfig): + print(OmegaConf.to_yaml(cfg)) + + x_train, y_train, x_test, y_test = utils.load_dataset(cfg) + x_train = pooling.avg_pool( + x_train.T, window_shape=(cfg.pool,), strides=(cfg.pool,) + ).T + x_test = pooling.avg_pool( + x_test.T, window_shape=(cfg.pool,), strides=(cfg.pool,) + ).T + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + x_train, y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + if cfg.epsilon < jnp.inf: + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + ps = jnp.array([5 / 64, 7 / 64, 52 / 64]) * rho + key = jax.random.key(cfg.seed) + class_keys = jax.random.split(key, len(classes)) + r = jnp.sqrt(x_imbalanced.shape[1]) + protos = jnp.stack( + [ + coinpress.private_mean_jit( + x_imbalanced[y_imbalanced == i], ps, key=class_keys[i], r=r + ) + for i in classes + ] + ) + else: + protos = jnp.stack( + [x_imbalanced[y_imbalanced == i].mean(axis=0) for i in classes] + ) + dists_test = utils.pairwise_distance(protos, x_test) + test_acc = float((dists_test.argmin(axis=0) == y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") + + +if __name__ == "__main__": + main() diff --git a/research/dppl_2024/dppl_public.py b/research/dppl_2024/dppl_public.py new file mode 100644 index 00000000..0ae029ed --- /dev/null +++ b/research/dppl_2024/dppl_public.py @@ -0,0 +1,73 @@ +import warnings + +import hydra +import jax +import jax.numpy as jnp +import numpy as np +from omegaconf import DictConfig, OmegaConf + +from lib import public, utils + + +@hydra.main(config_path="conf", config_name="public", version_base=None) +def main(cfg: DictConfig): + print(OmegaConf.to_yaml(cfg)) + + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) + print( + f"Converted settings epsilon {cfg.epsilon} to rho {rho} to \ + exponential epsilon {actual_epsilon}" + ) + + x_train, y_train, x_test, y_test = utils.load_dataset(cfg) + x_public = utils.load_public_dataset(cfg) + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + x_train, y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + try: + jax.devices("gpu") + except RuntimeError: + warnings.warn("No GPU found, falling back to CPU. This will be slow.") + scores = jnp.stack( + [ + utils.scores_multiple( + x_imbalanced[y_imbalanced == target], + x_public, + cfg.min_score, + cfg.max_score, + ) + for target in classes + ] + ) + sensitivity = 1.0 + proto_idx_per_class = [] + for target in classes: + proto_idx_per_class.append( + public.exponential( + scores=scores[target], + sensitivity=sensitivity, + epsilon=actual_epsilon, + size=1, + monotonic=True, + key=int(cfg.seed + target), + ) + ) + public_protos = x_public[np.concatenate(proto_idx_per_class)].reshape( + len(classes), x_public.shape[-1] + ) + dists_test = utils.pairwise_distance(public_protos, x_test) + test_acc = float((dists_test.argmin(axis=0) == y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") + + +if __name__ == "__main__": + main() diff --git a/research/dppl_2024/dppl_public_topk.py b/research/dppl_2024/dppl_public_topk.py new file mode 100644 index 00000000..8e4d1a31 --- /dev/null +++ b/research/dppl_2024/dppl_public_topk.py @@ -0,0 +1,85 @@ +import warnings +from functools import partial + +import hydra +import jax +import jax.numpy as jnp +from omegaconf import DictConfig, OmegaConf + +from lib import public, utils + + +@hydra.main(config_path="conf", config_name="public_topk", version_base=None) +def main(cfg: DictConfig): + print(OmegaConf.to_yaml(cfg)) + + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) + print( + f"Converted settings epsilon {cfg.epsilon} to rho {rho} to exponential \ + epsilon {actual_epsilon}" + ) + + x_train, y_train, x_test, y_test = utils.load_dataset(cfg) + x_public = utils.load_public_dataset(cfg) + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + x_train, y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + try: + jax.devices("gpu") + except RuntimeError: + warnings.warn("No GPU found, falling back to CPU. This will be slow.") + scores = jnp.stack( + [ + utils.scores_multiple( + x_imbalanced[y_imbalanced == target], + x_public, + cfg.min_score, + cfg.max_score, + ) + for target in classes + ] + ) + c_idx = jnp.argsort(scores, axis=1, descending=True) + if cfg.epsilon < jnp.inf: + c = jnp.stack([scores[i, c_idx[i]] for i in range(scores.shape[0])]) + u = c - c[:, cfg.k - 1][:, jnp.newaxis] + with jax.experimental.enable_x64(): + logm = jax.vmap(partial(public.log_binom, k=cfg.k), in_axes=(0))( + jnp.arange(scores.shape[-1]) + ) + proto_idx_c = public.give_topk_proto_idx( + u, + logm, + cfg.k, + u.shape[0], + u.shape[1], + actual_epsilon, + cfg.seed, + ) + proto_idx = jnp.stack( + [ + c_idx[jnp.arange(c_idx.shape[0]), proto_idx_c[:, k_i]] + for k_i in range(cfg.k) + ] + ).T + else: + proto_idx = jnp.stack( + [c_idx[jnp.arange(c_idx.shape[0]), k_i] for k_i in range(cfg.k)] + ).T + public_protos = x_public[proto_idx.flatten()].reshape((*proto_idx.shape, -1)) + dists_test = utils.pairwise_distance(public_protos, x_test) + test_acc = float((dists_test.argmin(axis=0) == y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") + + +if __name__ == "__main__": + main() diff --git a/research/dppl_2024/env.yaml b/research/dppl_2024/env.yaml new file mode 100644 index 00000000..c3c4c39e --- /dev/null +++ b/research/dppl_2024/env.yaml @@ -0,0 +1,72 @@ +name: submission +channels: + - conda-forge +dependencies: + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - bzip2=1.0.8 + - ca-certificates=2024.2.2 + - ld_impl_linux-64=2.40 + - libffi=3.4.2 + - libgcc-ng=13.2.0 + - libgomp=13.2.0 + - libnsl=2.0.1 + - libsqlite=3.45.3 + - libuuid=2.38.1 + - libxcrypt=4.4.36 + - libzlib=1.2.13 + - ncurses=6.5 + - openssl=3.3.0 + - pip=24.0 + - python=3.10.14 + - readline=8.2 + - setuptools=69.5.1 + - tk=8.6.13 + - tzdata=2024a + - wheel=0.43.0 + - xz=5.2.6 + - pip: + - absl-py==2.1.0 + - antlr4-python3-runtime==4.9.3 + - chex==0.1.86 + - etils==1.7.0 + - flax==0.8.3 + - fsspec==2024.5.0 + - hydra-core==1.3.2 + - importlib-resources==6.4.0 + - jax==0.4.28 + - jax-cuda12-pjrt==0.4.28 + - jax-cuda12-plugin==0.4.28 + - jaxlib==0.4.28 + - markdown-it-py==3.0.0 + - mdurl==0.1.2 + - ml-dtypes==0.4.0 + - msgpack==1.0.8 + - nest-asyncio==1.6.0 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.4.5.8 + - nvidia-cuda-cupti-cu12==12.4.127 + - nvidia-cuda-nvcc-cu12==12.4.131 + - nvidia-cuda-nvrtc-cu12==12.4.127 + - nvidia-cuda-runtime-cu12==12.4.127 + - nvidia-cudnn-cu12==8.9.7.29 + - nvidia-cufft-cu12==11.2.1.3 + - nvidia-cusolver-cu12==11.6.1.9 + - nvidia-cusparse-cu12==12.3.1.170 + - nvidia-nccl-cu12==2.21.5 + - nvidia-nvjitlink-cu12==12.4.127 + - omegaconf==2.3.0 + - opt-einsum==3.3.0 + - optax==0.2.2 + - orbax-checkpoint==0.5.11 + - packaging==24.0 + - protobuf==5.26.1 + - pygments==2.18.0 + - pyyaml==6.0.1 + - rich==13.7.1 + - scipy==1.13.0 + - tensorstore==0.1.59 + - toolz==0.12.1 + - typing-extensions==4.11.0 + - zipp==3.18.2 +prefix: /opt/conda/envs/submission diff --git a/research/dppl_2024/hparams_mean.md b/research/dppl_2024/hparams_mean.md new file mode 100644 index 00000000..f6367c51 --- /dev/null +++ b/research/dppl_2024/hparams_mean.md @@ -0,0 +1,322 @@ +| dataset | imbalance_ratio | encoder | epsilon | pooling | +|:-----------|------------------:|:---------------------------------|----------:|----------:| +| cifar10 | 1 | dino_resnet50 | 0.1 | 1 | +| cifar10 | 1 | dino_resnet50 | 0.2 | 1 | +| cifar10 | 1 | dino_resnet50 | 1 | 1 | +| cifar10 | 1 | dino_resnet50 | 8 | 1 | +| cifar10 | 1 | vit_b_16 | 0.1 | 1 | +| cifar10 | 1 | vit_b_16 | 0.2 | 1 | +| cifar10 | 1 | vit_b_16 | 1 | 1 | +| cifar10 | 1 | vit_b_16 | 8 | 1 | +| cifar10 | 1 | vit_h_14 | 0.1 | 1 | +| cifar10 | 1 | vit_h_14 | 0.2 | 1 | +| cifar10 | 1 | vit_h_14 | 1 | 1 | +| cifar10 | 1 | vit_h_14 | 8 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar10 | 10 | dino_resnet50 | 0.1 | 20 | +| cifar10 | 10 | dino_resnet50 | 0.2 | 20 | +| cifar10 | 10 | dino_resnet50 | 1 | 1 | +| cifar10 | 10 | dino_resnet50 | 8 | 1 | +| cifar10 | 10 | vit_b_16 | 0.1 | 5 | +| cifar10 | 10 | vit_b_16 | 0.2 | 5 | +| cifar10 | 10 | vit_b_16 | 1 | 1 | +| cifar10 | 10 | vit_b_16 | 8 | 1 | +| cifar10 | 10 | vit_h_14 | 0.1 | 5 | +| cifar10 | 10 | vit_h_14 | 0.2 | 5 | +| cifar10 | 10 | vit_h_14 | 1 | 1 | +| cifar10 | 10 | vit_h_14 | 8 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar10 | 100 | dino_resnet50 | 0.1 | 5 | +| cifar10 | 100 | dino_resnet50 | 0.2 | 20 | +| cifar10 | 100 | dino_resnet50 | 1 | 1 | +| cifar10 | 100 | dino_resnet50 | 8 | 1 | +| cifar10 | 100 | vit_b_16 | 0.1 | 2 | +| cifar10 | 100 | vit_b_16 | 0.2 | 2 | +| cifar10 | 100 | vit_b_16 | 1 | 1 | +| cifar10 | 100 | vit_b_16 | 8 | 1 | +| cifar10 | 100 | vit_h_14 | 0.1 | 10 | +| cifar10 | 100 | vit_h_14 | 0.2 | 5 | +| cifar10 | 100 | vit_h_14 | 1 | 1 | +| cifar10 | 100 | vit_h_14 | 8 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar10 | 50 | dino_resnet50 | 0.1 | 20 | +| cifar10 | 50 | dino_resnet50 | 0.2 | 20 | +| cifar10 | 50 | dino_resnet50 | 1 | 1 | +| cifar10 | 50 | dino_resnet50 | 8 | 1 | +| cifar10 | 50 | vit_b_16 | 0.1 | 5 | +| cifar10 | 50 | vit_b_16 | 0.2 | 2 | +| cifar10 | 50 | vit_b_16 | 1 | 2 | +| cifar10 | 50 | vit_b_16 | 8 | 1 | +| cifar10 | 50 | vit_h_14 | 0.1 | 10 | +| cifar10 | 50 | vit_h_14 | 0.2 | 5 | +| cifar10 | 50 | vit_h_14 | 1 | 2 | +| cifar10 | 50 | vit_h_14 | 8 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar100 | 1 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 1 | dino_resnet50 | 0.2 | 5 | +| cifar100 | 1 | dino_resnet50 | 1 | 1 | +| cifar100 | 1 | dino_resnet50 | 8 | 1 | +| cifar100 | 1 | vit_b_16 | 0.1 | 5 | +| cifar100 | 1 | vit_b_16 | 0.2 | 2 | +| cifar100 | 1 | vit_b_16 | 1 | 1 | +| cifar100 | 1 | vit_b_16 | 8 | 1 | +| cifar100 | 1 | vit_h_14 | 0.1 | 20 | +| cifar100 | 1 | vit_h_14 | 0.2 | 5 | +| cifar100 | 1 | vit_h_14 | 1 | 1 | +| cifar100 | 1 | vit_h_14 | 8 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar100 | 10 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 10 | dino_resnet50 | 0.2 | 50 | +| cifar100 | 10 | dino_resnet50 | 1 | 1 | +| cifar100 | 10 | dino_resnet50 | 8 | 1 | +| cifar100 | 10 | vit_b_16 | 0.1 | 5 | +| cifar100 | 10 | vit_b_16 | 0.2 | 5 | +| cifar100 | 10 | vit_b_16 | 1 | 2 | +| cifar100 | 10 | vit_b_16 | 8 | 1 | +| cifar100 | 10 | vit_h_14 | 0.1 | 20 | +| cifar100 | 10 | vit_h_14 | 0.2 | 10 | +| cifar100 | 10 | vit_h_14 | 1 | 5 | +| cifar100 | 10 | vit_h_14 | 8 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar100 | 100 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 100 | dino_resnet50 | 0.2 | 50 | +| cifar100 | 100 | dino_resnet50 | 1 | 1 | +| cifar100 | 100 | dino_resnet50 | 8 | 1 | +| cifar100 | 100 | vit_b_16 | 0.1 | 5 | +| cifar100 | 100 | vit_b_16 | 0.2 | 5 | +| cifar100 | 100 | vit_b_16 | 1 | 5 | +| cifar100 | 100 | vit_b_16 | 8 | 2 | +| cifar100 | 100 | vit_h_14 | 0.1 | 20 | +| cifar100 | 100 | vit_h_14 | 0.2 | 10 | +| cifar100 | 100 | vit_h_14 | 1 | 10 | +| cifar100 | 100 | vit_h_14 | 8 | 2 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| cifar100 | 50 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 50 | dino_resnet50 | 0.2 | 50 | +| cifar100 | 50 | dino_resnet50 | 1 | 1 | +| cifar100 | 50 | dino_resnet50 | 8 | 1 | +| cifar100 | 50 | vit_b_16 | 0.1 | 5 | +| cifar100 | 50 | vit_b_16 | 0.2 | 5 | +| cifar100 | 50 | vit_b_16 | 1 | 2 | +| cifar100 | 50 | vit_b_16 | 8 | 1 | +| cifar100 | 50 | vit_h_14 | 0.1 | 20 | +| cifar100 | 50 | vit_h_14 | 0.2 | 10 | +| cifar100 | 50 | vit_h_14 | 1 | 10 | +| cifar100 | 50 | vit_h_14 | 8 | 2 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| flowers102 | 1 | dino_resnet50 | 0.1 | 5 | +| flowers102 | 1 | dino_resnet50 | 0.2 | 5 | +| flowers102 | 1 | dino_resnet50 | 1 | 5 | +| flowers102 | 1 | dino_resnet50 | 8 | 5 | +| flowers102 | 1 | vit_b_16 | 0.1 | 100 | +| flowers102 | 1 | vit_b_16 | 0.2 | 100 | +| flowers102 | 1 | vit_b_16 | 1 | 100 | +| flowers102 | 1 | vit_b_16 | 8 | 1 | +| flowers102 | 1 | vit_h_14 | 0.1 | 100 | +| flowers102 | 1 | vit_h_14 | 0.2 | 100 | +| flowers102 | 1 | vit_h_14 | 1 | 100 | +| flowers102 | 1 | vit_h_14 | 8 | 1 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| flowers102 | 10 | dino_resnet50 | 0.1 | 2 | +| flowers102 | 10 | dino_resnet50 | 0.2 | 2 | +| flowers102 | 10 | dino_resnet50 | 1 | 2 | +| flowers102 | 10 | dino_resnet50 | 8 | 100 | +| flowers102 | 10 | vit_b_16 | 0.1 | 100 | +| flowers102 | 10 | vit_b_16 | 0.2 | 100 | +| flowers102 | 10 | vit_b_16 | 1 | 100 | +| flowers102 | 10 | vit_b_16 | 8 | 5 | +| flowers102 | 10 | vit_h_14 | 0.1 | 1 | +| flowers102 | 10 | vit_h_14 | 0.2 | 1 | +| flowers102 | 10 | vit_h_14 | 1 | 1 | +| flowers102 | 10 | vit_h_14 | 8 | 5 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 10 | +| flowers102 | 100 | dino_resnet50 | 0.1 | 5 | +| flowers102 | 100 | dino_resnet50 | 0.2 | 5 | +| flowers102 | 100 | dino_resnet50 | 1 | 50 | +| flowers102 | 100 | dino_resnet50 | 8 | 2 | +| flowers102 | 100 | vit_b_16 | 0.1 | 20 | +| flowers102 | 100 | vit_b_16 | 0.2 | 20 | +| flowers102 | 100 | vit_b_16 | 1 | 2 | +| flowers102 | 100 | vit_b_16 | 8 | 2 | +| flowers102 | 100 | vit_h_14 | 0.1 | 50 | +| flowers102 | 100 | vit_h_14 | 0.2 | 50 | +| flowers102 | 100 | vit_h_14 | 1 | 5 | +| flowers102 | 100 | vit_h_14 | 8 | 10 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 10 | +| flowers102 | 50 | dino_resnet50 | 0.1 | 5 | +| flowers102 | 50 | dino_resnet50 | 0.2 | 20 | +| flowers102 | 50 | dino_resnet50 | 1 | 50 | +| flowers102 | 50 | dino_resnet50 | 8 | 1 | +| flowers102 | 50 | vit_b_16 | 0.1 | 5 | +| flowers102 | 50 | vit_b_16 | 0.2 | 5 | +| flowers102 | 50 | vit_b_16 | 1 | 5 | +| flowers102 | 50 | vit_b_16 | 8 | 5 | +| flowers102 | 50 | vit_h_14 | 0.1 | 50 | +| flowers102 | 50 | vit_h_14 | 0.2 | 50 | +| flowers102 | 50 | vit_h_14 | 1 | 50 | +| flowers102 | 50 | vit_h_14 | 8 | 10 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 10 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 10 | +| food101 | 1 | dino_resnet50 | 0.1 | 5 | +| food101 | 1 | dino_resnet50 | 0.2 | 2 | +| food101 | 1 | dino_resnet50 | 1 | 1 | +| food101 | 1 | dino_resnet50 | 8 | 1 | +| food101 | 1 | vit_b_16 | 0.1 | 5 | +| food101 | 1 | vit_b_16 | 0.2 | 2 | +| food101 | 1 | vit_b_16 | 1 | 1 | +| food101 | 1 | vit_b_16 | 8 | 1 | +| food101 | 1 | vit_h_14 | 0.1 | 10 | +| food101 | 1 | vit_h_14 | 0.2 | 2 | +| food101 | 1 | vit_h_14 | 1 | 1 | +| food101 | 1 | vit_h_14 | 8 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 5 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| food101 | 10 | dino_resnet50 | 0.1 | 50 | +| food101 | 10 | dino_resnet50 | 0.2 | 10 | +| food101 | 10 | dino_resnet50 | 1 | 1 | +| food101 | 10 | dino_resnet50 | 8 | 1 | +| food101 | 10 | vit_b_16 | 0.1 | 5 | +| food101 | 10 | vit_b_16 | 0.2 | 2 | +| food101 | 10 | vit_b_16 | 1 | 2 | +| food101 | 10 | vit_b_16 | 8 | 1 | +| food101 | 10 | vit_h_14 | 0.1 | 10 | +| food101 | 10 | vit_h_14 | 0.2 | 10 | +| food101 | 10 | vit_h_14 | 1 | 5 | +| food101 | 10 | vit_h_14 | 8 | 1 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| food101 | 100 | dino_resnet50 | 0.1 | 50 | +| food101 | 100 | dino_resnet50 | 0.2 | 50 | +| food101 | 100 | dino_resnet50 | 1 | 1 | +| food101 | 100 | dino_resnet50 | 8 | 1 | +| food101 | 100 | vit_b_16 | 0.1 | 5 | +| food101 | 100 | vit_b_16 | 0.2 | 2 | +| food101 | 100 | vit_b_16 | 1 | 2 | +| food101 | 100 | vit_b_16 | 8 | 1 | +| food101 | 100 | vit_h_14 | 0.1 | 10 | +| food101 | 100 | vit_h_14 | 0.2 | 10 | +| food101 | 100 | vit_h_14 | 1 | 5 | +| food101 | 100 | vit_h_14 | 8 | 2 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 5 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| food101 | 50 | dino_resnet50 | 0.1 | 50 | +| food101 | 50 | dino_resnet50 | 0.2 | 50 | +| food101 | 50 | dino_resnet50 | 1 | 1 | +| food101 | 50 | dino_resnet50 | 8 | 1 | +| food101 | 50 | vit_b_16 | 0.1 | 5 | +| food101 | 50 | vit_b_16 | 0.2 | 2 | +| food101 | 50 | vit_b_16 | 1 | 2 | +| food101 | 50 | vit_b_16 | 8 | 1 | +| food101 | 50 | vit_h_14 | 0.1 | 10 | +| food101 | 50 | vit_h_14 | 0.2 | 10 | +| food101 | 50 | vit_h_14 | 1 | 5 | +| food101 | 50 | vit_h_14 | 8 | 2 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 5 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| stl10 | 1 | dino_resnet50 | 0.1 | 10 | +| stl10 | 1 | dino_resnet50 | 0.2 | 10 | +| stl10 | 1 | dino_resnet50 | 1 | 1 | +| stl10 | 1 | dino_resnet50 | 8 | 1 | +| stl10 | 1 | vit_b_16 | 0.1 | 5 | +| stl10 | 1 | vit_b_16 | 0.2 | 2 | +| stl10 | 1 | vit_b_16 | 1 | 1 | +| stl10 | 1 | vit_b_16 | 8 | 1 | +| stl10 | 1 | vit_h_14 | 0.1 | 20 | +| stl10 | 1 | vit_h_14 | 0.2 | 5 | +| stl10 | 1 | vit_h_14 | 1 | 1 | +| stl10 | 1 | vit_h_14 | 8 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| stl10 | 10 | dino_resnet50 | 0.1 | 10 | +| stl10 | 10 | dino_resnet50 | 0.2 | 10 | +| stl10 | 10 | dino_resnet50 | 1 | 20 | +| stl10 | 10 | dino_resnet50 | 8 | 1 | +| stl10 | 10 | vit_b_16 | 0.1 | 5 | +| stl10 | 10 | vit_b_16 | 0.2 | 5 | +| stl10 | 10 | vit_b_16 | 1 | 2 | +| stl10 | 10 | vit_b_16 | 8 | 1 | +| stl10 | 10 | vit_h_14 | 0.1 | 20 | +| stl10 | 10 | vit_h_14 | 0.2 | 10 | +| stl10 | 10 | vit_h_14 | 1 | 5 | +| stl10 | 10 | vit_h_14 | 8 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| stl10 | 100 | dino_resnet50 | 0.1 | 100 | +| stl10 | 100 | dino_resnet50 | 0.2 | 10 | +| stl10 | 100 | dino_resnet50 | 1 | 50 | +| stl10 | 100 | dino_resnet50 | 8 | 1 | +| stl10 | 100 | vit_b_16 | 0.1 | 10 | +| stl10 | 100 | vit_b_16 | 0.2 | 5 | +| stl10 | 100 | vit_b_16 | 1 | 2 | +| stl10 | 100 | vit_b_16 | 8 | 1 | +| stl10 | 100 | vit_h_14 | 0.1 | 10 | +| stl10 | 100 | vit_h_14 | 0.2 | 10 | +| stl10 | 100 | vit_h_14 | 1 | 10 | +| stl10 | 100 | vit_h_14 | 8 | 5 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| stl10 | 50 | dino_resnet50 | 0.1 | 10 | +| stl10 | 50 | dino_resnet50 | 0.2 | 10 | +| stl10 | 50 | dino_resnet50 | 1 | 50 | +| stl10 | 50 | dino_resnet50 | 8 | 1 | +| stl10 | 50 | vit_b_16 | 0.1 | 10 | +| stl10 | 50 | vit_b_16 | 0.2 | 5 | +| stl10 | 50 | vit_b_16 | 1 | 1 | +| stl10 | 50 | vit_b_16 | 8 | 1 | +| stl10 | 50 | vit_h_14 | 0.1 | 10 | +| stl10 | 50 | vit_h_14 | 0.2 | 10 | +| stl10 | 50 | vit_h_14 | 1 | 10 | +| stl10 | 50 | vit_h_14 | 8 | 5 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | diff --git a/research/dppl_2024/hparams_public.md b/research/dppl_2024/hparams_public.md new file mode 100644 index 00000000..55964b70 --- /dev/null +++ b/research/dppl_2024/hparams_public.md @@ -0,0 +1,258 @@ +| dataset | imbalance_ratio | encoder | epsilon | d_max | d_min | +|:----------|------------------:|:---------------------------------|----------:|--------:|--------:| +| cifar10 | 1 | dino_resnet50 | 0.1 | 2 | 0 | +| cifar10 | 1 | dino_resnet50 | 0.2 | 2 | 0 | +| cifar10 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | +| cifar10 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | +| cifar10 | 1 | vit_b_16 | 0.1 | 2 | 0 | +| cifar10 | 1 | vit_b_16 | 0.2 | 2 | 0 | +| cifar10 | 1 | vit_b_16 | 1 | 2 | 0 | +| cifar10 | 1 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 0.1 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 0.2 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar10 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar10 | 10 | dino_resnet50 | 0.2 | 2 | 0 | +| cifar10 | 10 | dino_resnet50 | 1 | 2 | 0 | +| cifar10 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | +| cifar10 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar10 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar10 | 10 | vit_b_16 | 1 | 2 | 0 | +| cifar10 | 10 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 0.1 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 0.2 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar10 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar10 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar10 | 100 | dino_resnet50 | 1 | 2 | 0 | +| cifar10 | 100 | dino_resnet50 | 8 | 2 | 0 | +| cifar10 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar10 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar10 | 100 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar10 | 100 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar10 | 100 | vit_h_14 | 0.2 | 1.54 | 1.46 | +| cifar10 | 100 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 100 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar10 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar10 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar10 | 50 | dino_resnet50 | 1 | 2 | 0 | +| cifar10 | 50 | dino_resnet50 | 8 | 2 | 0 | +| cifar10 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar10 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar10 | 50 | vit_b_16 | 1 | 2 | 0 | +| cifar10 | 50 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 50 | vit_h_14 | 0.1 | 1.64 | 1.34 | +| cifar10 | 50 | vit_h_14 | 0.2 | 1.54 | 1.46 | +| cifar10 | 50 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 50 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 1 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 1 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 1 | dino_resnet50 | 1 | 2 | 0 | +| cifar100 | 1 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 1 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 1 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 1 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar100 | 1 | vit_b_16 | 8 | 1.5 | 1.42 | +| cifar100 | 1 | vit_h_14 | 0.1 | 1.58 | 1.54 | +| cifar100 | 1 | vit_h_14 | 0.2 | 1.6 | 1.58 | +| cifar100 | 1 | vit_h_14 | 1 | 2 | 0 | +| cifar100 | 1 | vit_h_14 | 8 | 2 | 0 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 10 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | +| cifar100 | 10 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 10 | vit_b_16 | 1 | 1.54 | 1.46 | +| cifar100 | 10 | vit_b_16 | 8 | 1.64 | 1.34 | +| cifar100 | 10 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar100 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| cifar100 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | +| cifar100 | 10 | vit_h_14 | 8 | 2 | 0 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 100 | dino_resnet50 | 1 | 1.6 | 1.58 | +| cifar100 | 100 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 100 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | +| cifar100 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| cifar100 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | +| cifar100 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 50 | dino_resnet50 | 1 | 1.6 | 1.58 | +| cifar100 | 50 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 50 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_b_16 | 8 | 2 | 0 | +| cifar100 | 50 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| cifar100 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | +| cifar100 | 50 | vit_h_14 | 8 | 1.64 | 1.34 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| food101 | 1 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | +| food101 | 1 | dino_resnet50 | 1 | 2 | 0 | +| food101 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | +| food101 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | +| food101 | 1 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 1 | vit_b_16 | 1 | 2 | 0 | +| food101 | 1 | vit_b_16 | 8 | 2 | 0 | +| food101 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | +| food101 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| food101 | 1 | vit_h_14 | 1 | 2 | 0 | +| food101 | 1 | vit_h_14 | 8 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1.56 | 1.5 | +| food101 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 10 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| food101 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | +| food101 | 10 | dino_resnet50 | 8 | 2 | 0 | +| food101 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| food101 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 10 | vit_b_16 | 1 | 2 | 0 | +| food101 | 10 | vit_b_16 | 8 | 2 | 0 | +| food101 | 10 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| food101 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| food101 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | +| food101 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | +| food101 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| food101 | 100 | dino_resnet50 | 1 | 1.6 | 1.58 | +| food101 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | +| food101 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| food101 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | +| food101 | 100 | vit_b_16 | 8 | 2 | 0 | +| food101 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| food101 | 100 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| food101 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | +| food101 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| food101 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| food101 | 50 | dino_resnet50 | 1 | 1.6 | 1.58 | +| food101 | 50 | dino_resnet50 | 8 | 2 | 0 | +| food101 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| food101 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | +| food101 | 50 | vit_b_16 | 8 | 2 | 0 | +| food101 | 50 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| food101 | 50 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| food101 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | +| food101 | 50 | vit_h_14 | 8 | 2 | 0 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 1 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 1 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 1 | dino_resnet50 | 1 | 2 | 0 | +| stl10 | 1 | dino_resnet50 | 8 | 2 | 0 | +| stl10 | 1 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 1 | vit_b_16 | 0.2 | 2 | 0 | +| stl10 | 1 | vit_b_16 | 1 | 2 | 0 | +| stl10 | 1 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | +| stl10 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| stl10 | 1 | vit_h_14 | 1 | 1.64 | 1.34 | +| stl10 | 1 | vit_h_14 | 8 | 1.64 | 1.34 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 10 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 10 | dino_resnet50 | 1 | 1.6 | 1.58 | +| stl10 | 10 | dino_resnet50 | 8 | 2 | 0 | +| stl10 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| stl10 | 10 | vit_b_16 | 1 | 2 | 0 | +| stl10 | 10 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 10 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| stl10 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| stl10 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | +| stl10 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 100 | dino_resnet50 | 1 | 1.6 | 1.58 | +| stl10 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | +| stl10 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| stl10 | 100 | vit_b_16 | 1 | 1.5 | 1.42 | +| stl10 | 100 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| stl10 | 100 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| stl10 | 100 | vit_h_14 | 1 | 1.5 | 1.42 | +| stl10 | 100 | vit_h_14 | 8 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 50 | dino_resnet50 | 1 | 1.6 | 1.58 | +| stl10 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | +| stl10 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| stl10 | 50 | vit_b_16 | 1 | 1.5 | 1.42 | +| stl10 | 50 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 50 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| stl10 | 50 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| stl10 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | +| stl10 | 50 | vit_h_14 | 8 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | diff --git a/research/dppl_2024/hparams_public_topk.md b/research/dppl_2024/hparams_public_topk.md new file mode 100644 index 00000000..c78f74a6 --- /dev/null +++ b/research/dppl_2024/hparams_public_topk.md @@ -0,0 +1,258 @@ +| dataset | imbalance_ratio | encoder | epsilon | d_max | d_min | K | +|:----------|------------------:|:---------------------------------|----------:|--------:|--------:|----:| +| cifar10 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 10 | +| cifar10 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 20 | +| cifar10 | 1 | dino_resnet50 | 1 | 2 | 0 | 10 | +| cifar10 | 1 | dino_resnet50 | 8 | 2 | 0 | 20 | +| cifar10 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 5 | +| cifar10 | 1 | vit_b_16 | 0.2 | 2 | 0 | 3 | +| cifar10 | 1 | vit_b_16 | 1 | 2 | 0 | 20 | +| cifar10 | 1 | vit_b_16 | 8 | 2 | 0 | 20 | +| cifar10 | 1 | vit_h_14 | 0.1 | 1.506 | 1.42 | 5 | +| cifar10 | 1 | vit_h_14 | 0.2 | 1.506 | 1.42 | 10 | +| cifar10 | 1 | vit_h_14 | 1 | 1.506 | 1.42 | 20 | +| cifar10 | 1 | vit_h_14 | 8 | 1.506 | 1.42 | 20 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 3 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 5 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 10 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| cifar10 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 3 | +| cifar10 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 3 | +| cifar10 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| cifar10 | 10 | dino_resnet50 | 8 | 2 | 0 | 10 | +| cifar10 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar10 | 10 | vit_b_16 | 1 | 2 | 0 | 3 | +| cifar10 | 10 | vit_b_16 | 8 | 2 | 0 | 20 | +| cifar10 | 10 | vit_h_14 | 0.1 | 1.506 | 1.42 | 3 | +| cifar10 | 10 | vit_h_14 | 0.2 | 1.506 | 1.42 | 10 | +| cifar10 | 10 | vit_h_14 | 1 | 1.506 | 1.42 | 20 | +| cifar10 | 10 | vit_h_14 | 8 | 2 | 0 | 20 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 5 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| cifar10 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 3 | +| cifar10 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar10 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar10 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| cifar10 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 5 | +| cifar10 | 100 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar10 | 100 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar10 | 100 | vit_h_14 | 1 | 1.506 | 1.42 | 3 | +| cifar10 | 100 | vit_h_14 | 8 | 1.506 | 1.42 | 5 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 5 | +| cifar10 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 3 | +| cifar10 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar10 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| cifar10 | 50 | dino_resnet50 | 8 | 2 | 0 | 3 | +| cifar10 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_b_16 | 8 | 2 | 0 | 3 | +| cifar10 | 50 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar10 | 50 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar10 | 50 | vit_h_14 | 1 | 1.506 | 1.42 | 10 | +| cifar10 | 50 | vit_h_14 | 8 | 2 | 0 | 20 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 2 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 10 | +| cifar100 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| cifar100 | 1 | dino_resnet50 | 8 | 2 | 0 | 10 | +| cifar100 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_b_16 | 1 | 1.64 | 1.34 | 5 | +| cifar100 | 1 | vit_b_16 | 8 | 2 | 0 | 10 | +| cifar100 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 1 | vit_h_14 | 1 | 1.506 | 1.42 | 10 | +| cifar100 | 1 | vit_h_14 | 8 | 1.506 | 1.42 | 20 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 3 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| cifar100 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar100 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | 5 | +| cifar100 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_b_16 | 8 | 1.64 | 1.34 | 5 | +| cifar100 | 10 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar100 | 10 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar100 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| cifar100 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | 10 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 3 | +| cifar100 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 3 | +| cifar100 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar100 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 3 | +| cifar100 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 2 | +| cifar100 | 100 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar100 | 100 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar100 | 100 | vit_h_14 | 1 | 1.54 | 1.46 | 1 | +| cifar100 | 100 | vit_h_14 | 8 | 1.54 | 1.46 | 2 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar100 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | 5 | +| cifar100 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_b_16 | 8 | 1.64 | 1.34 | 2 | +| cifar100 | 50 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar100 | 50 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar100 | 50 | vit_h_14 | 1 | 1.506 | 1.42 | 1 | +| cifar100 | 50 | vit_h_14 | 8 | 1.54 | 1.46 | 3 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 1 | +| food101 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | 10 | +| food101 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | 20 | +| food101 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_b_16 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 1 | vit_b_16 | 1 | 1.64 | 1.34 | 5 | +| food101 | 1 | vit_b_16 | 8 | 2 | 0 | 20 | +| food101 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 1 | vit_h_14 | 1 | 1.64 | 1.34 | 5 | +| food101 | 1 | vit_h_14 | 8 | 2 | 0 | 10 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 5 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| food101 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| food101 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| food101 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_b_16 | 1 | 1.64 | 1.34 | 2 | +| food101 | 10 | vit_b_16 | 8 | 1.64 | 1.34 | 5 | +| food101 | 10 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| food101 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | 5 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 2 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 5 | +| food101 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| food101 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 5 | +| food101 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 2 | +| food101 | 100 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | 2 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 2 | +| food101 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| food101 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| food101 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 2 | +| food101 | 50 | vit_b_16 | 8 | 1.64 | 1.34 | 3 | +| food101 | 50 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_h_14 | 8 | 1.64 | 1.34 | 3 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 3 | +| stl10 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| stl10 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | 20 | +| stl10 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_b_16 | 1 | 2 | 0 | 3 | +| stl10 | 1 | vit_b_16 | 8 | 2 | 0 | 10 | +| stl10 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_h_14 | 1 | 1.64 | 1.34 | 3 | +| stl10 | 1 | vit_h_14 | 8 | 2 | 0 | 20 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 2 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| stl10 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| stl10 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_b_16 | 8 | 2 | 0 | 3 | +| stl10 | 10 | vit_h_14 | 0.1 | 1.64 | 1.34 | 2 | +| stl10 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | 20 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 3 | +| stl10 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 2 | +| stl10 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_h_14 | 0.2 | 1.64 | 1.34 | 3 | +| stl10 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | 2 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 1 | +| stl10 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | 2 | +| stl10 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_b_16 | 8 | 2 | 0 | 1 | +| stl10 | 50 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_h_14 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_h_14 | 8 | 1.64 | 1.34 | 5 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 1 | diff --git a/research/dppl_2024/lib/__init__.py b/research/dppl_2024/lib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/research/dppl_2024/lib/coinpress.py b/research/dppl_2024/lib/coinpress.py new file mode 100644 index 00000000..b7133bda --- /dev/null +++ b/research/dppl_2024/lib/coinpress.py @@ -0,0 +1,84 @@ +from functools import partial + +import jax +import numpy as np +from jax import numpy as jnp + + +@jax.jit +def gaussian_tailbound_jit(d, b): + return (d + 2 * (d * jnp.log(1 / b)) ** 0.5 + 2 * jnp.log(1 / b)) ** 0.5 + + +@partial(jax.jit, static_argnames=("d",)) +def multivariate_mean_step_jit(x, c, r, p, n, d, subkey): + ## Determine a good clipping threshold + gamma = gaussian_tailbound_jit(d, 0.01) + clip_thresh = jnp.minimum( + (r**2 + 2 * r * 3 + gamma**2) ** 0.5, r + gamma + ) # 3 in place of sqrt(log(2/beta)) + + ## Round each of X1,...,Xn to the nearest point in the ball B2(c,clip_thresh) + x = x - c + mag_x = jnp.linalg.norm(x, axis=1) + + outside_ball_bool = mag_x > clip_thresh + x_hat = (x.T / mag_x).T + x = jnp.where( + outside_ball_bool[:, jnp.newaxis], + c + (x_hat * clip_thresh), + x, + ) + + ## Compute sensitivity + delta = 2 * clip_thresh / n.astype(float) + sd = delta / (2 * p) ** 0.5 + + ## Add noise calibrated to sensitivity + y = sd * jax.random.normal(subkey, (d,)) + c = jnp.sum(x, axis=0) / n.astype(float) + y + r = (1 / n.astype(float) + sd**2) ** 0.5 * gaussian_tailbound_jit(d, 0.01) + return c, r + + +def multivariate_mean_iterative_jit_inner(i, val, x, ps, n, d, subkeys): + c, r = val + c, r = multivariate_mean_step_jit(x, c, r, ps[i], n, d, subkeys[i]) + return (c, r) + + +@partial(jax.jit, static_argnames=("d", "t")) +def multivariate_mean_iterative_jit(x, c, r, t, ps, n, d, key): + subkeys = jax.random.split(key, t) + init_val = c, r + (c, r) = jax.lax.fori_loop( + 0, + t, + partial( + multivariate_mean_iterative_jit_inner, + x=x, + ps=ps, + n=n, + d=d, + subkeys=subkeys, + ), + init_val, + ) + return c + + +def private_mean_jit(x, ps, key=jax.random.key(42), r=None, c=None): + if len(x.shape) != 2: + raise ValueError( + "X must be a 2D array, but received shape: {}".format(x.shape) + ) + d = x.shape[1] + if r is None: + r = np.sqrt(d) * 0.9 + if c is None: + c = np.zeros(d) + t = len(ps) + mean = multivariate_mean_iterative_jit( + x, c=c, r=r, t=t, ps=ps, n=x.shape[0], d=d, key=key + ) + return mean diff --git a/research/dppl_2024/lib/public.py b/research/dppl_2024/lib/public.py new file mode 100644 index 00000000..0742f757 --- /dev/null +++ b/research/dppl_2024/lib/public.py @@ -0,0 +1,164 @@ +from functools import partial + +import jax +import numpy as np +from jax import numpy as jnp +from jax import scipy as jsc + + +def exponential( + scores: np.ndarray, + sensitivity: float, + epsilon: float, + size: int = 1, + max_fix: bool = True, + monotonic: bool = False, + key: int = 0, +) -> np.ndarray: + """Perform exponential sampling on the scores. + + Args: + scores (np.ndarray): The scores of the elements in R. + sensitivity (float): Sensitivity of the score function w.r.t. \ + the private data. + epsilon (float): pure-differential privacy parameter. + size (int, optional): Number of independent samplings to perform (e.g. \ + for reporting avg/std of accuracy). Defaults to 1. + max_fix (bool, optional): Perform a numeric fix by multiplying all \ + probablities with exp(-max_exponent). Defaults to True. + monotonic (bool, optional): Use lower privacy bound when the score \ + function is monotonic w.r.t. to the private dataset. Defaults to False. + key (int, optional): Random key for reproducibility. Defaults to 0. + + Returns: + np.ndarray: array of indice(s) of the sampled element(s). + """ + if np.isposinf(epsilon): + max_idx = scores.argmax() + max_idx = max_idx.repeat(size) + return max_idx + + sensitivity_factor = 1 if monotonic else 2 + + # Substract maximum exponent to avoid overflow + if max_fix: + max_exponent = epsilon * scores.max() / (sensitivity_factor * sensitivity) + else: + max_exponent = 0 + # Calculate the probability for each element, based on its score + probabilities = np.exp( + epsilon * scores / (sensitivity_factor * sensitivity) - max_exponent + ) + # Normalize the probabilties so they sum to 1 + probabilities = probabilities / np.linalg.norm(probabilities, ord=1) + + # Choose an element from R based on the probabilities + rng = np.random.default_rng(key) + return rng.choice(len(scores), size, p=probabilities, replace=True) + + +@jax.jit +def log_binom(n: int, k: int) -> float: + """Calculate log(n choose k) + + Args: + n (int): n + k (int): k + + Returns: + float: log(n choose k) + """ + return ( + jsc.special.gammaln(n + 1) + - jsc.special.gammaln(k + 1) + - jsc.special.gammaln(n - k + 1) + ) + + +@partial( + jax.jit, + static_argnames=["total_rows", "total_cols"], +) +def exponential_parallel( + u: jnp.ndarray, + logm: jnp.ndarray, + total_rows: int, + total_cols: int, + epsilon: float, + key: int = 42, +) -> jnp.ndarray: + """Perform parallel exponential sampling of all classes. + + Args: + u (jnp.ndarray): Scores, shape (total_rows, total_cols) \ + = (n_classes, n_public_samples). + logm (jnp.ndarray): Log-Counts of the Utilities, \ + shape (total_cols,) = (n_public_samples,). + total_rows (int): U.shape[0] = n_classes. (needed for jit compilation) + total_cols (int): U.shape[1] = n_public_samples. \ + (needed for jit compilation) + epsilon (float): pure-differential privacy parameter. + key (int, optional): PRNG-initialization. Defaults to 42. + + Returns: + jnp.ndarray: array of indice(s) of the sampled element(s), \ + shape (total_rows,) = (n_classes,). + """ + rng = jax.random.key(key) + choices = ( + jnp.log(jnp.log(1 / jax.random.uniform(rng, (total_rows, total_cols)))) + - logm + - epsilon * u / 2 + ).argmin(axis=-1) + return choices + + +@partial(jax.jit, static_argnames=("total_rows", "total_cols", "k")) +def give_topk_proto_idx( + u: jnp.ndarray, + logm: jnp.ndarray, + k: int, + total_rows: int, + total_cols: int, + epsilon: float, + key: int = 42, +): + """Perform the private top-k prototyping. First, perform exponential sampling + on the utilities. + Then, uniformly sample the remaining k-1 prototypes, s.t. their utility is + equal or better. + + Args: + u (jnp.ndarray): Scores, shape (total_rows, total_cols) \ + = (n_classes, n_public_samples). + logm (jnp.ndarray): Log-Counts of the Utilities, \ + shape (total_cols,) = (n_public_samples,). + k (int): Number of prototypes per class to sample. + total_rows (int): U.shape[0] = n_classes. (needed for jit compilation) + total_cols (int): U.shape[1] = n_public_samples. \ + (needed for jit compilation) + epsilon (float): pure-differential privacy parameter. + key (int, optional): PRNG-initialization. Defaults to 42. + + Returns: + jnp.ndarray: array of indice(s) of the sampled element(s), \ + shape (total_rows, k) = (n_classes, k). + """ + choices = exponential_parallel( + u, logm, total_rows, total_cols, epsilon, key + ).astype(int) + + proto_idx_c = jnp.concatenate( + [ + jax.lax.select( + jnp.arange(total_cols)[jnp.newaxis, :].repeat(total_rows, axis=0) + < choices[:, jnp.newaxis], + -jax.random.uniform(jax.random.key(key), (total_rows, total_cols)), + jnp.stack([jnp.zeros((total_cols)) for row in jnp.arange(total_rows)]), + ).argsort(axis=-1)[:, : k - 1], + choices[:, jnp.newaxis], + ], + axis=1, + ) + + return proto_idx_c diff --git a/research/dppl_2024/lib/utils.py b/research/dppl_2024/lib/utils.py new file mode 100644 index 00000000..4da22696 --- /dev/null +++ b/research/dppl_2024/lib/utils.py @@ -0,0 +1,145 @@ +from functools import partial + +import numpy as np +from jax import jit, vmap +from jax import numpy as jnp +from omegaconf import DictConfig + + +def load_dataset(cfg: DictConfig): + x_train = np.load(cfg.dataset.train_data) + y_train = np.load(cfg.dataset.train_labels) + x_test = np.load(cfg.dataset.test_data) + y_test = np.load(cfg.dataset.test_labels) + + return x_train, y_train, x_test, y_test + + +def load_public_dataset(cfg: DictConfig): + x_public = np.load(cfg.dataset.public_data) + return x_public + + +def decay( + cls: int | np.ndarray, max_samples: int, num_classes: int, ratio: float = 10 +): + decay = -np.log(ratio) / num_classes + return np.round(max_samples * np.exp(decay * cls)).astype(int) + + +def give_imbalanced_set(x, y, imbalance_ratio: float = 10, seed: int = 42): + classes = np.unique(y) + x_classes = [x[y == i] for i in classes] + rng = np.random.default_rng(seed) + input_samples_per_class = np.asarray([(y == i).sum() for i in classes]) + + output_samples_per_class = decay( + np.linspace(0, len(classes), len(classes)), + max_samples=input_samples_per_class.min(), + num_classes=len(classes), + ratio=imbalance_ratio, + ) + rng.shuffle(output_samples_per_class) + x = np.concatenate( + [ + x_classes[i][:num_samples] + for i, num_samples in enumerate(output_samples_per_class) + ] + ) + y = np.concatenate( + [ + np.repeat(i, num_samples) + for i, num_samples in enumerate(output_samples_per_class) + ] + ) + return x, y + + +def zcdp_of_naive_epsilon(epsilon): + return epsilon**2 / 2 + + +def exponential_epsilon_of_zcdp(rho): + return np.sqrt(8 * rho) + + +@jit +def pairwise_distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Calculate 1-cosine_similarity between x and y. + + Args: + x (jnp.ndarray): x + y (jnp.ndarray): y + + Returns: + jnp.ndarray: pairwise distance(s) + """ + x = x / jnp.linalg.norm(x, axis=-1, keepdims=True) + y = y / jnp.linalg.norm(y, axis=-1, keepdims=True) + return 1 - jnp.dot(x, y.T) + + +@jit +def scores_single( + x: jnp.ndarray, y: jnp.ndarray, min_score: float, max_score: float +) -> jnp.ndarray: + """Score Calculation for a single public sample. The score is calculated as + the sum of the clipped pairwise distances between the public sample and the + private samples. + + Args: + x (jnp.ndarray): private dataset + y (jnp.ndarray): public sample + min_score (float): minimum score (in [0,2)) + max_score (float): maximum score (in (min_score, 2]) + + Returns: + jnp.ndarray: Score of the public sample + """ + return jnp.sum( + ( + ( + jnp.clip( + 2 - vmap(pairwise_distance, in_axes=(0, None))(x, y), + min_score, + max_score, + ) + - min_score + ) + / (max_score - min_score) + ), + axis=0, + ) + + +@partial(jit, static_argnames=["batch_size_y"]) +def scores_multiple( + x: jnp.ndarray, + y: jnp.ndarray, + min_score: float = 0.0, + max_score: float = 2.0, + batch_size_y: int = 5000, +) -> jnp.ndarray: + """Perform the score calculation batched over the public samples. + + Args: + x (jnp.ndarray): private dataset + y (jnp.ndarray): public dataset + min_score (float, optional): minimum score (in [0,2)). Defaults to 0.0. + max_score (float, optional): maximum score (in (min_score, 2]). \ + Defaults to 2.0. + batch_size_y (int, optional): batch size (impacts VRAM usage). \ + Defaults to 5000. + + Returns: + jnp.ndarray: Scores of all public samples in Y + """ + return jnp.concatenate( + [ + vmap( + partial(scores_single, min_score=min_score, max_score=max_score), + in_axes=(None, 0), + )(x, y[i : min(i + batch_size_y, len(y))]) + for i in range(0, len(y), batch_size_y) + ], + ) diff --git a/research/dppl_2024/requirements.txt b/research/dppl_2024/requirements.txt new file mode 100644 index 00000000..0e83ed25 --- /dev/null +++ b/research/dppl_2024/requirements.txt @@ -0,0 +1,43 @@ +absl-py==2.1.0 +antlr4-python3-runtime==4.9.3 +chex==0.1.86 +etils==1.7.0 +flax==0.8.3 +fsspec==2024.5.0 +hydra-core==1.3.2 +importlib_resources==6.4.0 +jax==0.4.28 +jax-cuda12-pjrt==0.4.28 +jax-cuda12-plugin==0.4.28 +jaxlib==0.4.28 +markdown-it-py==3.0.0 +mdurl==0.1.2 +ml-dtypes==0.4.0 +msgpack==1.0.8 +nest-asyncio==1.6.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvcc-cu12==12.4.131 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==8.9.7.29 +nvidia-cufft-cu12==11.2.1.3 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +omegaconf==2.3.0 +opt-einsum==3.3.0 +optax==0.2.2 +orbax-checkpoint==0.5.11 +packaging==24.0 +protobuf==5.26.1 +Pygments==2.18.0 +PyYAML==6.0.1 +rich==13.7.1 +scipy==1.13.0 +tensorstore==0.1.59 +toolz==0.12.1 +typing_extensions==4.11.0 +zipp==3.18.2