
Commit d6414c5

Merge pull request #32 from pomonam/low_precision
Merge dev branch
2 parents: 1bb64f5 + 8541589


51 files changed: +39,714 / -6,183 lines

DOCUMENTATION.md

Lines changed: 1 addition & 2 deletions
@@ -258,8 +258,7 @@ Kronfluence computes covariance matrices for all data points.
 - `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
 For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices
 are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate
-covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU
-can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter).
+covariance matrices will be saved in disk. It is also helpful when using low precision.
 - `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
 For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices
 are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
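
For context, a minimal sketch of how these partitioning options might be passed in practice is shown below. It assumes the `covariance_data_partitions` and `covariance_module_partitions` fields live on Kronfluence's `FactorArguments` and that the `prepare_model`/`Analyzer`/`fit_all_factors` calls match the repository's README; `model`, `task`, and `train_dataset` are placeholders for your own setup, not code from this commit.

```python
# Hedged sketch: partitioned covariance computation with Kronfluence.
# Field names are taken from the documentation excerpt above; verify them
# against the version of the library you have installed.
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
    strategy="ekfac",
    covariance_data_partitions=2,    # dataset split into 2 chunks, aggregated later
    covariance_module_partitions=2,  # modules split into 2 chunks to bound GPU memory
)

model = prepare_model(model=model, task=task)  # placeholders: your nn.Module and Task
analyzer = Analyzer(analysis_name="example", model=model, task=task)
analyzer.fit_all_factors(
    factors_name="ekfac_partitioned",
    dataset=train_dataset,  # placeholder: your training dataset
    factor_args=factor_args,
)
```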

examples/README.md

Lines changed: 8 additions & 8 deletions
@@ -18,14 +18,14 @@ Our examples cover the following tasks:
 
 <div align="center">
 
-| Task | Example Datasets |
-|----------------------|:------------------------:|
-| Regression | UCI |
-| Image Classification | CIFAR-10 & ImageNet |
-| Text Classification | GLUE |
-| Multiple-Choice | SWAG |
-| Summarization | CNN/DailyMail |
-| Language Modeling | WikiText-2 & OpenWebText |
+| Task | Example Datasets |
+|----------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
+| Regression | [UCI](https://github.com/pomonam/kronfluence/tree/main/examples/uci) |
+| Image Classification | [CIFAR-10](https://github.com/pomonam/kronfluence/tree/main/examples/cifar) & [ImageNet](https://github.com/pomonam/kronfluence/tree/main/examples/imagenet) |
+| Text Classification | [GLUE](https://github.com/pomonam/kronfluence/tree/main/examples/glue) |
+| Multiple-Choice | [SWAG](https://github.com/pomonam/kronfluence/tree/main/examples/swag) |
+| Summarization | [CNN/DailyMail](https://github.com/pomonam/kronfluence/tree/main/examples/dailymail) |
+| Language Modeling | [WikiText-2](https://github.com/pomonam/kronfluence/tree/main/examples/wikitext) & [OpenWebText](https://github.com/pomonam/kronfluence/tree/main/examples/openwebtext) |
 
 </div>

examples/cifar/README.md

Lines changed: 11 additions & 10 deletions
@@ -1,6 +1,6 @@
 # CIFAR-10 & ResNet-9 Example
 
-This directory contains scripts for training ResNet-9 and computing influence scores on CIFAR-10 dataset. The pipeline is motivated from
+This directory contains scripts for training ResNet-9 and computing influence scores on the CIFAR-10 dataset. The pipeline is motivated from the
 [TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). To get started, please install the necessary packages by running the following command:
 
 ```bash
@@ -9,7 +9,7 @@ pip install -r requirements.txt
 
 ## Training
 
-To train ResNet-9 on the CIFAR-10 dataset, run the following command:
+To train ResNet-9 on CIFAR-10, execute:
 
 ```bash
 python train.py --dataset_dir ./data \
@@ -35,7 +35,8 @@ python analyze.py --query_batch_size 1000 \
 --factor_strategy ekfac
 ```
 
-In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors):
+In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`.
+On an A100 (80GB) GPU, computation takes approximately 2 minutes, including EKFAC factor calculation:
 
 ```
 ----------------------------------------------------------------------------------------------------------------------------------
@@ -57,7 +58,7 @@ In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as t
 ----------------------------------------------------------------------------------------------------------------------------------
 ```
 
-To use AMP when computing influence scores, run:
+To use AMP for faster computation, add the `--use_half_precision` flag:
 
 ```bash
 python analyze.py --query_batch_size 1000 \
@@ -89,20 +90,20 @@ This reduces computation time to about 40 seconds on an A100 (80GB) GPU:
 ----------------------------------------------------------------------------------------------------------------------------------
 ```
 
-You can run `half_precision_analysis.py` to verify that the scores computed with AMP have high correlations with those of the default configuration.
+Run `half_precision_analysis.py` to verify that AMP-computed scores maintain high correlations with default configuration scores.
 
 <p align="center">
 <a href="#"><img width="380" img src="figure/half_precision.png" alt="Half Precision"/></a>
 </p>
 
 ## Visualizing Influential Training Images
 
-[This Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing) provides a tutorial on visualizing the top influential training images.
+For a tutorial on visualizing top influential training images, refer to [this Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing)
 
 ## Mislabeled Data Detection
 
 We can use self-influence scores (see **Section 5.4** for the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
-First, train the model with 10% of the training examples mislabeled by running:
+First, train the model with 10% of the training examples mislabeled:
 
 ```bash
 python train.py --dataset_dir ./data \
@@ -116,7 +117,7 @@ python train.py --dataset_dir ./data \
 --seed 1004
 ```
 
-Then, compute the self-influence scores with:
+Then compute self-influence scores:
 
 ```bash
 python detect_mislabeled_dataset.py --dataset_dir ./data \
@@ -125,7 +126,7 @@ python detect_mislabeled_dataset.py --dataset_dir ./data \
 --factor_strategy ekfac
 ```
 
-On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence scores:
+On an A100 (80GB) GPU, this takes approximately 2 minutes:
 
 ```
 ----------------------------------------------------------------------------------------------------------------------------------
@@ -147,7 +148,7 @@ On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence
 ----------------------------------------------------------------------------------------------------------------------------------
 ```
 
-Around 80% of mislabeled data points can be detected by inspecting 10% of the dataset (97% by inspecting 20%).
+By inspecting just 10% of the dataset, about 80% of mislabeled data points can be detected (97% by inspecting 20%).
 
 <p align="center">
 <a href="#"><img width="380" img src="figure/mislabel.png" alt="Mislabeled Data Detection"/></a>
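
The README hunk above describes the detection procedure only in prose, so here is a hypothetical sketch of the ranking step it implies: sort training points by self-influence score and measure how many known-corrupted indices fall in the inspected fraction. All names and the synthetic data are illustrative assumptions, not code from the repository.

```python
# Hedged sketch of self-influence-based mislabel detection.
import numpy as np

def detection_rate(self_scores: np.ndarray, mislabeled_idx: np.ndarray, inspect_fraction: float) -> float:
    """Fraction of mislabeled points recovered when inspecting the top-scoring examples."""
    num_inspect = int(len(self_scores) * inspect_fraction)
    ranked = np.argsort(-self_scores)[:num_inspect]  # mislabeled points tend to score high
    return np.isin(ranked, mislabeled_idx).sum() / len(mislabeled_idx)

# Synthetic sanity check (replace with scores produced by detect_mislabeled_dataset.py).
rng = np.random.default_rng(0)
scores = rng.normal(size=50_000)
corrupted = rng.choice(50_000, size=5_000, replace=False)  # the 10% of flipped labels
scores[corrupted] += 2.0  # corrupted points score higher on average in this toy setup
print(detection_rate(scores, corrupted, inspect_fraction=0.10))
```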

examples/cifar/half_precision_analysis.py

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@ def main():
     plt.rcParams["axes.axisbelow"] = True
 
     # Only plot first 3000 points to avoid clutter.
-    idx = 79
+    idx = 0
     plt.scatter(half_scores[idx][:3000], scores[idx][:3000], edgecolor="k")
     plt.grid()
     plt.xlabel("bfloat16")
@@ -36,9 +36,11 @@ def main():
 
     # Compute the averaged spearman correlation.
     all_corr = []
-    for i in range(100):
+    for i in range(2000):
         all_corr.append(spearmanr(scores[i], half_scores[i])[0])
     logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}")
+    logging.info(f"Lowest Spearman Correlation: {np.array(all_corr).min()}")
+    logging.info(f"Highest Spearman Correlation: {np.array(all_corr).max()}")
 
 
 if __name__ == "__main__":
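
The hunks above iterate over query rows and compare bfloat16 scores against full-precision scores with a per-query Spearman rank correlation. Below is a minimal, self-contained sketch of that comparison; the random arrays stand in for the score files that `half_precision_analysis.py` loads, so the shapes and noise level are assumptions rather than the script's actual inputs.

```python
# Hedged sketch of the per-query Spearman comparison in half_precision_analysis.py.
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
scores = rng.normal(size=(100, 5_000))                             # stand-in for full-precision scores (queries x train)
half_scores = scores + rng.normal(scale=0.05, size=scores.shape)   # stand-in for bfloat16 scores with simulated noise

all_corr = np.array([spearmanr(scores[i], half_scores[i])[0] for i in range(scores.shape[0])])
print(f"Averaged Spearman Correlation: {all_corr.mean():.4f}")
print(f"Lowest Spearman Correlation:   {all_corr.min():.4f}")
print(f"Highest Spearman Correlation:  {all_corr.max():.4f}")
```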
