From e0da3d2f35549f8c218a2edaf4b0bad94acfb3a1 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 15 Jul 2024 10:42:27 -0400 Subject: [PATCH] Change cifar10 example --- examples/cifar/README.md | 11 ++++++----- examples/cifar/analyze.py | 2 ++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/cifar/README.md b/examples/cifar/README.md index 5474295..c45ed30 100644 --- a/examples/cifar/README.md +++ b/examples/cifar/README.md @@ -1,6 +1,6 @@ # CIFAR-10 & ResNet-9 Example -This directory contains scripts for training ResNet-9 and computing influence scores on CIFAR-10 dataset. The pipeline is motivated from +This directory contains scripts for training ResNet-9 and computing influence scores on the CIFAR-10 dataset. The pipeline is motivated by the [TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). To get started, please install the necessary packages by running the following command: ```bash @@ -9,7 +9,7 @@ pip install -r requirements.txt ## Training -To train ResNet-9 on the CIFAR-10 dataset, run the following command: +To train ResNet-9 on CIFAR-10, execute: ```bash python train.py --dataset_dir ./data \ @@ -35,7 +35,8 @@ python analyze.py --query_batch_size 1000 \ --factor_strategy ekfac ``` -In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors): +In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. 
+On an A100 (80GB) GPU, computation takes approximately 2 minutes, including EKFAC factor calculation: ``` ---------------------------------------------------------------------------------------------------------------------------------- @@ -57,7 +58,7 @@ In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as t ---------------------------------------------------------------------------------------------------------------------------------- ``` -To use AMP when computing influence scores, run: +To use AMP for faster computation, add the `--use_half_precision` flag: ```bash python analyze.py --query_batch_size 1000 \ @@ -89,7 +90,7 @@ This reduces computation time to about 40 seconds on an A100 (80GB) GPU: ---------------------------------------------------------------------------------------------------------------------------------- ``` -You can run `half_precision_analysis.py` to verify that the scores computed with AMP have high correlations with those of the default configuration. +Run `half_precision_analysis.py` to verify that AMP-computed scores maintain high correlations with default configuration scores.

Half Precision diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py index b0d7d84..69c4faf 100644 --- a/examples/cifar/analyze.py +++ b/examples/cifar/analyze.py @@ -163,6 +163,8 @@ def main(): scores_name = factor_args.strategy if args.use_half_precision: score_args = all_low_precision_score_arguments(dtype=torch.bfloat16) + score_args.precondition_dtype = torch.float32 + score_args.per_sample_gradient_dtype = torch.float32 scores_name += "_half" analyzer.compute_pairwise_scores( scores_name=scores_name,