diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml
index b844a53..f0bcb2e 100644
--- a/.github/workflows/python-test.yml
+++ b/.github/workflows/python-test.yml
@@ -28,8 +28,10 @@ jobs:
pytest -vx tests/test_dataset_utils.py
pytest -vx tests/test_testable_tasks.py
pytest -vx tests/factors/test_covariances.py
- pytest -vx tests/factors/test_eigens.py
+ pytest -vx tests/factors/test_eigendecompositions.py
+ pytest -vx tests/factors/test_lambdas.py
pytest -vx tests/modules/test_modules.py
pytest -vx tests/modules/test_per_sample_gradients.py
+ pytest -vx tests/modules/test_matmul.py
pytest -vx tests/scores/test_pairwise_scores.py
pytest -vx tests/scores/test_self_scores.py
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index debad09..f4b18f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,8 +164,8 @@ cython_debug/
# Checkpoints and influence outputs
checkpoints/
-analyses/
influence_results/
data/
+cache/
*.pth
*.pt
\ No newline at end of file
diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index b0f1e5c..cdc5cf1 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -67,7 +67,7 @@ class YourTask(Task):
) -> torch.Tensor:
# TODO: Complete this method.
- def tracked_modules(self) -> Optional[List[str]]:
+ def get_influence_tracked_modules(self) -> Optional[List[str]]:
# TODO: [Optional] Complete this method.
return None # Compute influence scores on all available modules.
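+ # For example, to restrict influence computations to specific modules, return
+ # their names instead (hypothetical names; inspect model.named_modules() to
+ # find the real ones for your model):
+ # return ["model.mlp1", "model.mlp2"]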
@@ -89,7 +89,7 @@ model = prepare_model(model=model, task=task)
...
```
-If you have specified specific module names in `Task.tracked_modules`, `TrackedModule` will only be installed for these modules.
+If you have specified module names in `Task.get_influence_tracked_modules`, `TrackedModule` will only be installed for these modules.
**\[Optional\] Create a DDP and FSDP Module.**
After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or
@@ -140,7 +140,7 @@ Try rewriting the model so that it uses supported modules (as done for the `conv
Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module.
If there are specific modules you would like to see supported, please submit an issue.
-**How should I write task.tracked_modules?**
+**How should I write task.get_influence_tracked_modules?**
We recommend using all supported modules for influence computations. However, if you would like to compute influence scores
on a subset of the modules (e.g., influence computations only on the MLP layers of a transformer, or only on the last layer),
inspect `model.named_modules()` to determine what modules to use. You can specify the list of module names you want to analyze.
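+
+As a minimal sketch (assuming `model` is the model you will pass to `prepare_model`), you can list candidate module names with:
+
+```python
+for name, module in model.named_modules():
+    # Print each module's name and type to decide which ones to track.
+    print(name, type(module).__name__)
+```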
@@ -183,7 +183,7 @@ def forward(x: torch.Tensor) -> torch.Tensor:
> [!WARNING]
> The default arguments assume the module is used only once during the forward pass.
> If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
-> `shared_parameters_exist=True` in `FactorArguments`.
+> `has_shared_parameters=True` in `FactorArguments`.
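+>
+> A minimal sketch of enabling this (other arguments left at their defaults):
+>
+> ```python
+> factor_args = FactorArguments(has_shared_parameters=True)
+> ```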
**Why are there so many arguments?**
Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments`
@@ -204,14 +204,13 @@ from kronfluence.arguments import FactorArguments
factor_args = FactorArguments(
strategy="ekfac", # Choose from "identity", "diagonal", "kfac", or "ekfac".
use_empirical_fisher=False,
- distributed_sync_steps=1000,
amp_dtype=None,
- shared_parameters_exist=False,
+ has_shared_parameters=False,
# Settings for covariance matrix fitting.
covariance_max_examples=100_000,
- covariance_data_partition_size=1,
- covariance_module_partition_size=1,
+ covariance_data_partitions=1,
+ covariance_module_partitions=1,
activation_covariance_dtype=torch.float32,
gradient_covariance_dtype=torch.float32,
@@ -220,10 +219,10 @@ factor_args = FactorArguments(
# Settings for Lambda matrix fitting.
lambda_max_examples=100_000,
- lambda_data_partition_size=1,
- lambda_module_partition_size=1,
- lambda_iterative_aggregate=False,
- cached_activation_cpu_offload=False,
+ lambda_data_partitions=1,
+ lambda_module_partitions=1,
+ use_iterative_lambda_aggregation=False,
+ offload_activations_to_cpu=False,
per_sample_gradient_dtype=torch.float32,
lambda_dtype=torch.float32,
)
@@ -237,7 +236,7 @@ You can change:
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from the batch)
instead of the true Fisher (using sampled labels from the model's predictions). It is recommended to keep this `False`.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
-- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass.
+- `has_shared_parameters`: Specifies whether shared parameters exist in the forward pass. (A short example follows this list.)
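+
+For instance, a minimal sketch that fits all factors with the `diagonal` strategy under bfloat16 AMP (illustrative values; `analyzer` and `train_dataset` are assumed to be set up as in the earlier snippets):
+
+```python
+factor_args = FactorArguments(strategy="diagonal", amp_dtype=torch.bfloat16)
+analyzer.fit_all_factors(
+    factors_name="diagonal_factors",
+    dataset=train_dataset,
+    per_device_batch_size=128,
+    factor_args=factor_args,
+)
+```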
### Fitting Covariance Matrices
@@ -254,13 +253,13 @@ covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_fa
This step corresponds to **Equation 16** in the paper. You can tune:
- `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. When set to `None`,
Kronfluence computes covariance matrices for all data points.
-- `covariance_data_partition_size`: Number of data partitions to use for computing covariance matrices.
-For example, when `covariance_data_partition_size = 2`, the dataset is split into 2 chunks and covariance matrices
+- `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
+For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices
are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful under GPU preemption, as intermediate
covariance matrices will be saved to disk. It can also be helpful when launching multiple parallel jobs, where each GPU
can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter; a configuration sketch follows this list).
-- `covariance_module_partition_size`: Number of module partitions to use for computing covariance matrices.
-For example, when `covariance_module_partition_size = 2`, the module is split into 2 chunks and covariance matrices
+- `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
+For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
covariance matrices cannot fit into GPU memory). However, this will require multiple iterations over the dataset and can be slow.
- `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16`
@@ -271,7 +270,7 @@ or `torch.float16`.
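+
+As referenced above, a partitioned configuration might look like the following sketch (the batch size is an arbitrary placeholder):
+
+```python
+factor_args = FactorArguments(
+    strategy="ekfac",
+    covariance_data_partitions=2,    # Aggregate covariance matrices over 2 data chunks.
+    covariance_module_partitions=2,  # Process modules in 2 chunks to lower peak memory.
+)
+analyzer.fit_covariance_matrices(
+    factors_name="initial_factor",
+    dataset=train_dataset,
+    per_device_batch_size=128,
+    factor_args=factor_args,
+)
+```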
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting covariance matrices.
2. Try using lower precision for `activation_covariance_dtype` and `gradient_covariance_dtype`.
-3. Try setting `covariance_module_partition_size > 1`.
+3. Try setting `covariance_module_partitions > 1`.
### Performing Eigendecomposition
@@ -301,22 +300,22 @@ lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")
This corresponds to **Equation 20** in the paper. You can tune:
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
-- `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices.
-- `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices.
-- `cached_activation_cpu_offload`: Computing the per-sample-gradient requires saving the intermediate activation in memory.
-You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
-- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
+- `lambda_data_partitions`: Number of data partitions to use for computing Lambda matrices.
+- `lambda_module_partitions`: Number of module partitions to use for computing Lambda matrices.
+- `offload_activations_to_cpu`: Computing the per-sample-gradient requires saving the intermediate activations in memory.
+You can set `offload_activations_to_cpu=True` to cache these activations on the CPU. This is helpful for dealing with OOMs, but will make the overall computation slower (see the sketch after this list).
+- `use_iterative_lambda_aggregation`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16`
or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
-or `torch.float16`. Recommended to use `torch.float32`.
+or `torch.float16`.
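+
+As mentioned above, a memory-conscious sketch might combine these options (illustrative values):
+
+```python
+factor_args = FactorArguments(
+    strategy="ekfac",
+    use_iterative_lambda_aggregation=True,  # Compute Lambda matrices with for-loops.
+    offload_activations_to_cpu=True,        # Cache intermediate activations on the CPU.
+)
+analyzer.fit_lambda_matrices(
+    factors_name="initial_factor",
+    dataset=train_dataset,
+    per_device_batch_size=32,
+    factor_args=factor_args,
+)
+```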
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
-2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`. (Try out `lambda_iterative_aggregate=True` first.)
+2. Try setting `use_iterative_lambda_aggregation=True` or `offload_activations_to_cpu=True`. (Try out `use_iterative_lambda_aggregation=True` first.)
3. Try using lower precision for `per_sample_gradient_dtype` and `lambda_dtype`.
-4. Try using `lambda_module_partition_size > 1`.
+4. Try using `lambda_module_partitions > 1`.
### FAQs
@@ -339,21 +338,24 @@ import torch
from kronfluence.arguments import ScoreArguments
score_args = ScoreArguments(
- damping=1e-08,
- cached_activation_cpu_offload=False,
- distributed_sync_steps=1000,
+ damping_factor=1e-08,
amp_dtype=None,
+ offload_activations_to_cpu=False,
# More functionalities to compute influence scores.
- data_partition_size=1,
- module_partition_size=1,
- per_module_score=False,
+ data_partitions=1,
+ module_partitions=1,
+ compute_per_module_scores=False,
+ compute_per_token_scores=False,
use_measurement_for_self_influence=False,
+ aggregate_query_gradients=False,
+ aggregate_train_gradients=False,
# Configuration for query batching.
- query_gradient_rank=None,
+ query_gradient_low_rank=None,
+ use_full_svd=False,
query_gradient_svd_dtype=torch.float32,
- num_query_gradient_accumulations=1,
+ query_gradient_accumulation_steps=1,
# Configuration for dtype.
score_dtype=torch.float32,
@@ -362,23 +364,25 @@ score_args = ScoreArguments(
)
```
-- `damping`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
+- `damping_factor`: The damping factor for the damped inverse Hessian-vector product (iHVP). If `None`, uses a heuristic based on mean eigenvalues
`(0.1 x mean eigenvalues)`, as done in [this paper](https://arxiv.org/abs/2308.03296).
-- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
-- `data_partition_size`: Number of data partitions for computing influence scores.
-- `module_partition_size`: Number of module partitions for computing influence scores.
-- `per_module_score`: Whether to return a per-module influence scores. Instead of summing over influences across
+- `offload_activations_to_cpu`: Whether to offload cached activations to CPU.
+- `data_partitions`: Number of data partitions for computing influence scores.
+- `module_partitions`: Number of module partitions for computing influence scores.
+- `compute_per_module_scores`: Whether to return per-module influence scores. Instead of summing over influences across
all modules, this will keep track of intermediate module-wise scores.
-- - `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
-- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
+- `compute_per_token_scores`: Whether to return per-token influence scores. This is only applicable to transformer-based models.
+- `aggregate_query_gradients`: Whether to use the summed query gradient instead of per-sample query gradients.
+- `aggregate_train_gradients`: Whether to use the summed training gradient instead of per-sample training gradients.
+- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
+- `query_gradient_low_rank`: The rank for query batching (a low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used. (A configuration sketch follows this list.)
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) on the query batch. You can also use `torch.float64`.
-- `num_query_gradient_accumulations`: Number of query gradients to accumulate over. For example, when `num_query_gradient_accumulations=2` with
+- `query_gradient_accumulation_steps`: Number of query gradients to accumulate over. For example, when `query_gradient_accumulation_steps=2` with
`query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
-- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
-but `torch.float32` is recommended.
+- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`.
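+
+As referenced above, query batching might be configured as in the following sketch (the rank and accumulation values are illustrative):
+
+```python
+score_args = ScoreArguments(
+    query_gradient_low_rank=64,           # Low-rank approximation of the preconditioned query gradients.
+    query_gradient_accumulation_steps=2,  # Accumulate 2 query batches before computing dot products.
+)
+```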
### Computing Influence Scores
@@ -409,12 +413,12 @@ vector will correspond to `g_m^T ⋅ H^{-1} ⋅ g_l`, where `g_m` is the gradien
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
-2. Try setting `cached_activation_cpu_offload=True`.
+2. Try setting `offload_activations_to_cpu=True`.
3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
4. Try using lower precision for `precondition_dtype`.
-5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
+5. Try setting `query_gradient_low_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
batching is only supported for computing pairwise influence scores, not self-influence scores.
-6. Try setting `module_partition_size > 1`.
+6. Try setting `module_partitions > 1`.
### FAQs
diff --git a/README.md b/README.md
index 9711087..ec6b611 100644
--- a/README.md
+++ b/README.md
@@ -182,7 +182,7 @@ Please address any reported issues before submitting your PR.
## Acknowledgements
[Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
-I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
+I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
## License
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 7c16350..5d23191 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -4,10 +4,11 @@ accelerate>=0.31.0
einops>=0.8.0
einconv>=0.1.0
opt_einsum>=3.3.0
+scikit-learn>=1.4.0
safetensors>=0.4.2
tqdm>=4.66.4
datasets>=2.20.0
-transformers>=4.41.2
+transformers>=4.42.0
isort==5.13.2
pylint==3.2.3
pytest==8.2.2
diff --git a/examples/README.md b/examples/README.md
index 4d58a09..328a50f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -12,22 +12,22 @@ pip install -r requirements.txt
Alternatively, navigate to each example folder and run `pip install -r requirements.txt`.
-
## List of Tasks
Our examples cover the following tasks:
-| Task | Example datasets |
+| Task | Example Datasets |
|----------------------|:------------------------:|
| Regression | UCI |
-| Image Classification | CIFAR-10 / ImageNet |
+| Image Classification | CIFAR-10 & ImageNet |
| Text Classification | GLUE |
| Multiple-Choice | SWAG |
-| Language Modeling | WikiText-2 / OpenWebText |
+| Summarization | CNN/DailyMail |
+| Language Modeling | WikiText-2 & OpenWebText |
These examples demonstrate various use cases of Kronfluence, including the usage of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel).
-Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.
+Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.
\ No newline at end of file
diff --git a/examples/cifar/README.md b/examples/cifar/README.md
index c3cda87..5474295 100644
--- a/examples/cifar/README.md
+++ b/examples/cifar/README.md
@@ -26,7 +26,7 @@ This will train the model using the specified hyperparameters and save the train
## Computing Pairwise Influence Scores
-To compute pairwise influence scores on 2000 query data points using the `ekfac` factorization strategy, run the following command:
+To compute pairwise influence scores on 2000 query data points using the `ekfac` strategy, run the following command:
```bash
python analyze.py --query_batch_size 1000 \
@@ -41,23 +41,23 @@ In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as t
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 112.83 | 100 % |
+| Total | - | 11 | 106.38 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 47.989 | 1 | 47.989 | 42.532 |
-| Fit Lambda | 34.639 | 1 | 34.639 | 30.7 |
-| Fit Covariance | 21.841 | 1 | 21.841 | 19.357 |
-| Save Pairwise Score | 3.5998 | 1 | 3.5998 | 3.1905 |
-| Perform Eigendecomposition | 2.7724 | 1 | 2.7724 | 2.4572 |
-| Save Covariance | 0.85695 | 1 | 0.85695 | 0.75951 |
-| Save Eigendecomposition | 0.85628 | 1 | 0.85628 | 0.75892 |
-| Save Lambda | 0.12327 | 1 | 0.12327 | 0.10925 |
-| Load Eigendecomposition | 0.056494 | 1 | 0.056494 | 0.05007 |
-| Load All Factors | 0.048981 | 1 | 0.048981 | 0.043412 |
-| Load Covariance | 0.046798 | 1 | 0.046798 | 0.041476 |
+| Compute Pairwise Score | 46.745 | 1 | 46.745 | 43.941 |
+| Fit Lambda | 34.885 | 1 | 34.885 | 32.793 |
+| Fit Covariance | 22.538 | 1 | 22.538 | 21.187 |
+| Perform Eigendecomposition | 0.91424 | 1 | 0.91424 | 0.85941 |
+| Save Pairwise Score | 0.81219 | 1 | 0.81219 | 0.76348 |
+| Save Covariance | 0.22351 | 1 | 0.22351 | 0.21011 |
+| Save Eigendecomposition | 0.21617 | 1 | 0.21617 | 0.20321 |
+| Save Lambda | 0.031038 | 1 | 0.031038 | 0.029177 |
+| Load Eigendecomposition | 0.010442 | 1 | 0.010442 | 0.0098156 |
+| Load All Factors | 0.0026517 | 1 | 0.0026517 | 0.0024927 |
+| Load Covariance | 0.0016419 | 1 | 0.0016419 | 0.0015435 |
----------------------------------------------------------------------------------------------------------------------------------
```
-To use AMP when computing influence scores (in addition to half precision when computing influence factors and scores), run:
+To use AMP when computing influence scores, run:
```bash
python analyze.py --query_batch_size 1000 \
@@ -73,19 +73,19 @@ This reduces computation time to about 40 seconds on an A100 (80GB) GPU:
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 42.316 | 100 % |
+| Total | - | 11 | 35.965 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 19.565 | 1 | 19.565 | 46.235 |
-| Fit Lambda | 9.173 | 1 | 9.173 | 21.677 |
-| Fit Covariance | 7.3723 | 1 | 7.3723 | 17.422 |
-| Perform Eigendecomposition | 2.6613 | 1 | 2.6613 | 6.2891 |
-| Save Pairwise Score | 2.0156 | 1 | 2.0156 | 4.7633 |
-| Save Covariance | 0.71699 | 1 | 0.71699 | 1.6944 |
-| Save Eigendecomposition | 0.52561 | 1 | 0.52561 | 1.2421 |
-| Load Covariance | 0.15732 | 1 | 0.15732 | 0.37177 |
-| Save Lambda | 0.063394 | 1 | 0.063394 | 0.14981 |
-| Load Eigendecomposition | 0.051395 | 1 | 0.051395 | 0.12146 |
-| Load All Factors | 0.014144 | 1 | 0.014144 | 0.033425 |
+| Compute Pairwise Score | 18.012 | 1 | 18.012 | 50.082 |
+| Fit Lambda | 9.2271 | 1 | 9.2271 | 25.656 |
+| Fit Covariance | 7.134 | 1 | 7.134 | 19.836 |
+| Perform Eigendecomposition | 0.87962 | 1 | 0.87962 | 2.4457 |
+| Save Pairwise Score | 0.45432 | 1 | 0.45432 | 1.2632 |
+| Save Covariance | 0.12861 | 1 | 0.12861 | 0.35759 |
+| Save Eigendecomposition | 0.11296 | 1 | 0.11296 | 0.31407 |
+| Save Lambda | 0.010712 | 1 | 0.010712 | 0.029784 |
+| Load All Factors | 0.002736 | 1 | 0.002736 | 0.0076074 |
+| Load Covariance | 0.0016696 | 1 | 0.0016696 | 0.0046421 |
+| Load Eigendecomposition | 0.0014892 | 1 | 0.0014892 | 0.0041406 |
----------------------------------------------------------------------------------------------------------------------------------
```
@@ -131,19 +131,19 @@ On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 122.28 | 100 % |
+| Total | - | 11 | 121.85 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
-| Compute Self-Influence Score | 61.999 | 1 | 61.999 | 50.701 |
-| Fit Lambda | 34.629 | 1 | 34.629 | 28.319 |
-| Fit Covariance | 21.807 | 1 | 21.807 | 17.833 |
-| Perform Eigendecomposition | 1.8041 | 1 | 1.8041 | 1.4754 |
-| Save Covariance | 0.86378 | 1 | 0.86378 | 0.70638 |
-| Save Eigendecomposition | 0.84935 | 1 | 0.84935 | 0.69458 |
-| Save Lambda | 0.18367 | 1 | 0.18367 | 0.1502 |
-| Load Eigendecomposition | 0.052867 | 1 | 0.052867 | 0.043233 |
-| Load Covariance | 0.051723 | 1 | 0.051723 | 0.042298 |
-| Load All Factors | 0.031986 | 1 | 0.031986 | 0.026158 |
-| Save Self-Influence Score | 0.010352 | 1 | 0.010352 | 0.0084653 |
+| Compute Self-Influence Score | 62.778 | 1 | 62.778 | 51.519 |
+| Fit Lambda | 35.174 | 1 | 35.174 | 28.866 |
+| Fit Covariance | 22.582 | 1 | 22.582 | 18.532 |
+| Perform Eigendecomposition | 0.82656 | 1 | 0.82656 | 0.67832 |
+| Save Covariance | 0.2478 | 1 | 0.2478 | 0.20336 |
+| Save Eigendecomposition | 0.22042 | 1 | 0.22042 | 0.18088 |
+| Save Lambda | 0.018463 | 1 | 0.018463 | 0.015152 |
+| Load All Factors | 0.0027554 | 1 | 0.0027554 | 0.0022612 |
+| Load Covariance | 0.0016607 | 1 | 0.0016607 | 0.0013628 |
+| Load Eigendecomposition | 0.0015408 | 1 | 0.0015408 | 0.0012645 |
+| Save Self-Influence Score | 0.0010841 | 1 | 0.0010841 | 0.00088966 |
----------------------------------------------------------------------------------------------------------------------------------
```
diff --git a/examples/cifar/half_precision_analysis.py b/examples/cifar/half_precision_analysis.py
index c62c14b..3f9a41d 100644
--- a/examples/cifar/half_precision_analysis.py
+++ b/examples/cifar/half_precision_analysis.py
@@ -1,6 +1,8 @@
import logging
import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import spearmanr
from tueplots import markers
from kronfluence.analyzer import Analyzer
@@ -25,13 +27,19 @@ def main():
plt.rcParams["axes.axisbelow"] = True
# Only plot first 3000 points to avoid clutter.
- idx = 0
+ idx = 79
plt.scatter(half_scores[idx][:3000], scores[idx][:3000], edgecolor="k")
plt.grid()
plt.xlabel("bfloat16")
plt.ylabel("float32")
plt.show()
+ # Compute the averaged Spearman correlation.
+ all_corr = []
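+ # Average over the first 100 query data points.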
+ for i in range(100):
+ all_corr.append(spearmanr(scores[i], half_scores[i])[0])
+ logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}")
+
if __name__ == "__main__":
main()
diff --git a/examples/cifar/requirements.txt b/examples/cifar/requirements.txt
index 5a65422..a667c1f 100644
--- a/examples/cifar/requirements.txt
+++ b/examples/cifar/requirements.txt
@@ -1,4 +1,3 @@
scikit-learn
-jupyter
matplotlib
tueplots
\ No newline at end of file
diff --git a/examples/dailymail/README.md b/examples/dailymail/README.md
new file mode 100644
index 0000000..a7e5ab6
--- /dev/null
+++ b/examples/dailymail/README.md
@@ -0,0 +1,73 @@
+# CNN/DailyMail & T5 Example
+
+This directory contains scripts for fine-tuning T5 and computing influence scores on the CNN/DailyMail dataset. The pipeline is adapted from [this HuggingFace example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization).
+To begin, install the necessary packages:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Training
+
+To fine-tune T5 on CNN/DailyMail, run the following command:
+
+```bash
+python train.py --checkpoint_dir ./checkpoints \
+ --train_batch_size 16 \
+ --eval_batch_size 32 \
+ --learning_rate 5e-05 \
+ --weight_decay 0.01 \
+ --num_train_epochs 3 \
+ --seed 1004
+```
+
+This will fine-tune the model using the specified hyperparameters and save the final checkpoint in the `./checkpoints` directory.
+
+## Computing Pairwise Influence Scores
+
+To calculate pairwise influence scores on 10 query data points using `ekfac`, run:
+
+```bash
+python analyze.py --factor_batch_size 64 \
+ --query_batch_size 10 \
+ --train_batch_size 128 \
+ --use_half_precision \
+ --checkpoint_dir ./checkpoints \
+ --factor_strategy ekfac
+```
+
+Alternative options for `factor_strategy` include `identity`, `diagonal`, and `kfac`. On an A100 (80GB), computing the pairwise scores (including EKFAC factors) takes approximately 1 hour:
+
+```
+----------------------------------------------------------------------------------------------------------------------------------
+| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Total | - | 11 | 3397.1 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 1905.6 | 1 | 1905.6 | 56.093 |
+| Fit Lambda | 747.5 | 1 | 747.5 | 22.004 |
+| Fit Covariance | 734.03 | 1 | 734.03 | 21.607 |
+| Perform Eigendecomposition | 8.4236 | 1 | 8.4236 | 0.24796 |
+| Save Eigendecomposition | 0.79164 | 1 | 0.79164 | 0.023303 |
+| Save Covariance | 0.60366 | 1 | 0.60366 | 0.01777 |
+| Save Lambda | 0.1514 | 1 | 0.1514 | 0.0044566 |
+| Load All Factors | 0.027977 | 1 | 0.027977 | 0.00082354 |
+| Save Pairwise Score | 0.01082 | 1 | 0.01082 | 0.00031851 |
+| Load Covariance | 0.010015 | 1 | 0.010015 | 0.0002948 |
+| Load Eigendecomposition | 0.0096806 | 1 | 0.0096806 | 0.00028497 |
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
+## Inspecting Top Influential Sequences
+
+The `inspect_examples.py` script prints top influential sequences for a given query.
+
+```
+Query Data Example:
+ Input: summarize: (CNN)My vote for Father of the Year goes to Curt Schilling. The former Major League Baseball pitcher recently fired off a series of fastballs and mowed down a group of Twitter trolls who made the mistake of tweeting vulgar and sexually-explicit comments about Schilling's teenage daughter. The drama started, innocently enough, on February 25, when Schilling played the role of a proud father. He sent a tweet congratulating his daughter, Gabby, on being accepted to Salve Regina University, where she'll play softball. It read: "Congrats to Gabby Schilling who will pitch for the Salve Regina Seahawks next year!! — Curt Schilling (@gehrig38)" Almost immediately, responses came in from young men, complete strangers who apparently followed Schilling on Twitter. The tweets quickly went from immature, to creepy, to repugnant. Threats of rape were common. The tweets were deleted, and the accounts were closed after this story went viral. But not before Schilling captured some of the images and posted them on his blog. What was said about 17-year-old Gabby Schilling wasn't just obnoxious. It was vile and obscene. What was said wasn't just mean and ugly. It was threatening and scary. As a parent, it's the kind of thing that makes you rethink your opposition to public caning as a logical punishment for such transgressions. These misogynistic cowards may have thought they could hide in the darkness of anonymity, the sort that many have come to expect from social media sites, where you feel free to be a despicable human being because, you think, no one will ever find out who you really are and hold you accountable for your words. If so, they thought wrong. They couldn't hide. They were found out, and they got the throttling they so richly deserved. Thanks to dad. According to Schilling, who made it his mission to track down these cretins and make sure those they associate with know who they really are, two people have already paid a price due to their tweets. One was a student disc jockey at a community college in New Jersey, who was suspended, and the other was a part-time ticket seller for the New York Yankees, who was fired. Concerned that this is an example of exactly the kind of cyberbullying that leads some teenagers to commit suicide, Schilling is also thinking about taking legal action against some of the other people involved. Bravo for him. I'm sure that, all across America, dads with daughters -- after reading some of the horrible things that were said about this young girl -- are marveling at Schilling's self-control. I have two daughters of my own, and he's a better man than me. If ever there was a case where profanity-spewing malcontents deserved to have their mouths washed out with soap, this is it. So what additional insights can we draw, and what larger lessons can we learn, from this unexpected but predictable collision of old-fashioned parenthood and newfangled media? There are a few. The first is about accountability, the very thing that the young men who posted these hurtful messages were trying to avoid. But Schilling wouldn't let them. At their best, social media sites like Twitter, Facebook, Instagram and others allow the sharing the information and the building of a sense of community. At their worst, they become digital sandboxes and locker rooms where people think have a license to misbehave without having to worry about consequences. We need to applaud efforts like this that promote greater online accountability. 
There's also something to be said about protective parents, and how essential they are to a working society. We should still be concerned about those overprotective parents who hover like helicopters from little league to job interviews. We shouldn't bubblewrap our kids, and keep them from playing outdoors, and then sit around wondering why they're soft, timid, and risk-averse. But protective parents -- the kind who shield their kids from real danger -- never go out of style. A parent's top job is to protect his children. Schilling did his job. Finally, it's worth reminding everyone that freedom of expression does not mean freedom from rules, standards, and expectations that should guide your behavior. There are things you don't say. There are boundaries, ways that we expect you to behave so you don't terrorize other people or bring shame upon yourself, your friends, and
+ Label: Ruben Navarrette: Schilling deserves praise for taking on online haters for offensive comments about his daughter. Navarrette: In protecting his child, Schilling set a model for parenting and taught us a lesson about social media.
+
+Top Influential Example:
+ Input: summarize: (CNN) -- What is it with juries in high-profile cases in Southern California? Over the years, they've become a national joke. But no one is laughing. Instead, with each travesty of justice and every acquittal that should have been a conviction, you're left wondering just what trial these 12 folks were watching and questioning whether we should have higher standards for who sits on a jury. Sometimes, the juries in local and state courts get it wrong, and the Justice Department must step in and make it right. Think back to the acquittal in April 1992 of four Los Angeles police officers who, one year earlier, savagely beat motorist Rodney King. They walked out of a courtroom in Simi Valley, California, as free men -- sparking days of rioting, looting and violence. At the time, the conventional thinking on newspaper editorial pages and on talk radio was the jurors in that largely white suburb of Los Angeles, which was itself home to many active-duty and retired police officers, saw the police force as their line of defense against undesirables like King. So naturally, the argument went, they would cut them some slack. The officers were tried again, and convicted in federal court of violating King's civil rights. Justice was finally served. Here we go again. There hasn't been much civil unrest over what happened to Kelly Thomas, the homeless and mentally ill man who -- on July 5, 2011 -- was beaten to death by a swarm of police officers in Fullerton, California. But now that the verdict is in, literally, on the two former officers who were charged in his death, there is plenty of outrage on talk radio, online and in other public forums. Another 12 people who swore an oath to consider the evidence and the law and make sure that justice is served appear to have gotten it terribly wrong. This week, that jury in Santa Ana, California -- a city about 30 miles southeast of Los Angeles -- produced a wave of gasps in the courtroom when it announced that it had found Manuel Ramos, who had been charged with second-degree murder and involuntary manslaughter, and Jay Cicinelli, who was charged with involuntary manslaughter and excessive use of force, not guilty on all counts. What? The beating was caught on a surveillance tape. When you watch those 33 minutes of footage, assuming you can stomach the experience, it's hard to believe that anyone could declare the perpetrators "not guilty." The surveillance camera footage shows Thomas being beaten and stunned with a Taser by police until he was unrecognizable and unconscious. You see a defenseless and compliant young man screaming in pain, saying he's sorry and pleading for help from his father. His words will haunt you, "Daddy, help! They're killing me!" According to prosecutors, the young man suffered brain injuries, facial fractures, broken ribs and extensive bruises and abrasions. He wound up lying in a pool of blood. He died five days later. This was not a by-the-book case of police officers using all necessary force to subdue a suspect who was resisting arrest -- a suspect, by the way, who had committed no crime. This was not, as Ramos' attorney claimed, a case of police offices simply "doing their job" with "no malice in their heart." Check the video. Early on in the confrontation, Ramos appears to tell the young man who is sitting on the ground: "You see my fists? They're getting ready to f--- you up!" Another officer is heard telling a comrade: "We ran out of options so I got to the end of my Taser and I... 
smashed his face to hell." There is the malice. This was abuse of power and an instance of bullying behind a badge. It happens more than we'd like to think in America. But this time, it went too far. And a man died, and a family was shattered. Yet, the jury somehow missed all this? How does this happen? In Los Angeles, people are saying that the mentally ill are the new Rodney King. In the same way that the jury in Simi Valley was inclined to back the officers who it saw as protecting them from people like King, now the jury in Santa Ana is backing the officers who it counts on to prod people like Thomas to move along, leave the streets, and get out of sight. It's a plausible explanation
+ Label: Ruben Navarrette: Too many high-profile cases in California produce travesties of justice. He says jury acquitted two ex-cops in malicious beating death captured on video. He says case showed abuse of power, bullying behind badge; happens too often in U.S. Navarrette: Only one place left that can right this wrong: The Justice Department.
+```
\ No newline at end of file
diff --git a/examples/dailymail/__init__.py b/examples/dailymail/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/dailymail/analyze.py b/examples/dailymail/analyze.py
new file mode 100644
index 0000000..032cc08
--- /dev/null
+++ b/examples/dailymail/analyze.py
@@ -0,0 +1,275 @@
+import argparse
+import logging
+import os
+from typing import Dict, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import DataCollatorForSeq2Seq
+
+from examples.dailymail.pipeline import (
+ construct_t5,
+ get_dailymail_dataset,
+ get_tokenizer,
+)
+from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.arguments import FactorArguments, ScoreArguments
+from kronfluence.task import Task
+from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
+from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
+from kronfluence.utils.dataset import DataLoaderKwargs
+from kronfluence.utils.model import apply_ddp
+
+BATCH_TYPE = Dict[str, torch.Tensor]
+try:
+ LOCAL_RANK = int(os.environ["LOCAL_RANK"])
+ WORLD_RANK = int(os.environ["RANK"])
+ WORLD_SIZE = int(os.environ["WORLD_SIZE"])
+except KeyError:
+ LOCAL_RANK = WORLD_RANK = WORLD_SIZE = 0
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Influence analysis on CNN/DailyMail dataset.")
+
+ parser.add_argument(
+ "--checkpoint_dir",
+ type=str,
+ default="./checkpoints",
+ help="A path that is storing the final checkpoint of the model.",
+ )
+
+ parser.add_argument(
+ "--factor_strategy",
+ type=str,
+ default="ekfac",
+ help="Strategy to compute influence factors.",
+ )
+ parser.add_argument(
+ "--query_gradient_rank",
+ type=int,
+ default=-1,
+ help="Rank for the low-rank query gradient approximation.",
+ )
+ parser.add_argument(
+ "--use_half_precision",
+ action="store_true",
+ default=False,
+ help="Whether to use half precision for computing factors and scores.",
+ )
+ parser.add_argument(
+ "--use_ddp",
+ action="store_true",
+ default=False,
+ help="Whether to use DDP for computing factors and scores.",
+ )
+ parser.add_argument(
+ "--factor_batch_size",
+ type=int,
+ default=64,
+ help="Batch size for computing influence factors.",
+ )
+ parser.add_argument(
+ "--query_batch_size",
+ type=int,
+ default=10,
+ help="Batch size for computing query gradients.",
+ )
+ parser.add_argument(
+ "--train_batch_size",
+ type=int,
+ default=128,
+ help="Batch size for computing training gradients.",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ default=False,
+ help="Boolean flag to profile computations.",
+ )
+ args = parser.parse_args()
+
+ if args.checkpoint_dir is not None:
+ os.makedirs(args.checkpoint_dir, exist_ok=True)
+
+ return args
+
+
+class SummarizationTask(Task):
+ def compute_train_loss(
+ self,
+ batch: BATCH_TYPE,
+ model: nn.Module,
+ sample: bool = False,
+ ) -> torch.Tensor:
+ logits = model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ decoder_input_ids=batch["decoder_input_ids"],
+ ).logits
+
+ if not sample:
+ return F.cross_entropy(
+ logits.view(-1, logits.size(-1)), batch["labels"].view(-1), ignore_index=-100, reduction="sum"
+ )
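+ # Sample labels from the model's predictions to estimate the true Fisher
+ # (rather than the empirical Fisher, which uses the actual labels).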
+ with torch.no_grad():
+ probs = torch.nn.functional.softmax(logits.view(-1, logits.size(-1)).detach(), dim=-1)
+ sampled_labels = torch.multinomial(
+ probs,
+ num_samples=1,
+ ).flatten()
+ masks = batch["labels"].view(-1) == -100
+ sampled_labels[masks] = -100
+ return F.cross_entropy(logits.view(-1, logits.size(-1)), sampled_labels, reduction="sum")
+
+ def compute_measurement(
+ self,
+ batch: BATCH_TYPE,
+ model: nn.Module,
+ ) -> torch.Tensor:
+ # Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
+ logits = model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ decoder_input_ids=batch["decoder_input_ids"],
+ ).logits
+ logits = logits.view(-1, logits.size(-1))
+
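+ # Measurement: margin between the correct-class logit and the logsumexp of all other logits.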
+ labels = batch["labels"].view(-1)
+ bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
+ logits_correct = logits[bindex, labels]
+
+ cloned_logits = logits.clone()
+ cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)
+
+ margins = logits_correct - cloned_logits.logsumexp(dim=-1)
+ masks = batch["labels"].view(-1) != -100
+ return -margins[masks].sum()
+
+ def get_influence_tracked_modules(self) -> List[str]:
+ total_modules = []
+
+ # Add attention layers:
+ for i in range(6):
+ total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.q")
+ total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.k")
+ total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.v")
+ total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.o")
+
+ total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.q")
+ total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.k")
+ total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.v")
+ total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.o")
+
+ total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.q")
+ total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.k")
+ total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.v")
+ total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.o")
+
+ # Add MLP layers:
+ for i in range(6):
+ total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wi")
+ total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wo")
+
+ total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wi")
+ total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wo")
+
+ return total_modules
+
+ def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
+ return batch["attention_mask"]
+
+
+def main():
+ args = parse_args()
+ logging.basicConfig(level=logging.INFO)
+
+ # Prepare the dataset.
+ train_dataset = get_dailymail_dataset(
+ split="eval_train",
+ )
+ eval_dataset = get_dailymail_dataset(
+ split="valid",
+ )
+ tokenizer = get_tokenizer()
+
+ # Prepare the trained model.
+ model = construct_t5()
+
+ # Define task and prepare model.
+ task = SummarizationTask()
+ model = prepare_model(model, task)
+
+ if args.use_ddp:
+ model = apply_ddp(
+ model=model,
+ local_rank=LOCAL_RANK,
+ rank=WORLD_RANK,
+ world_size=WORLD_SIZE,
+ )
+
+ analyzer = Analyzer(
+ analysis_name="dailymail",
+ model=model,
+ task=task,
+ profile=args.profile,
+ )
+ # Configure parameters for DataLoader.
+ label_pad_token_id = -100
+ data_collator = DataCollatorForSeq2Seq(
+ tokenizer,
+ model=model,
+ label_pad_token_id=label_pad_token_id,
+ pad_to_multiple_of=None,
+ )
+
+ dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=data_collator)
+ analyzer.set_dataloader_kwargs(dataloader_kwargs)
+
+ # Compute influence factors.
+ factors_name = args.factor_strategy
+ factor_args = FactorArguments(strategy=args.factor_strategy)
+ if args.use_half_precision:
+ factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16)
+ factors_name += "_half"
+ if args.use_ddp:
+ factors_name += "_ddp"
+ analyzer.fit_all_factors(
+ factors_name=factors_name,
+ dataset=train_dataset,
+ per_device_batch_size=args.factor_batch_size,
+ factor_args=factor_args,
+ overwrite_output_dir=False,
+ )
+
+ # Compute pairwise scores.
+ score_args = ScoreArguments()
+ scores_name = factor_args.strategy
+ if args.use_half_precision:
+ score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
+ scores_name += "_half"
+ rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
+ if rank is not None:
+ score_args.query_gradient_low_rank = rank
+ score_args.query_gradient_accumulation_steps = 10
+ scores_name += f"_qlr{rank}"
+ if args.use_ddp:
+ scores_name += "_ddp"
+ analyzer.compute_pairwise_scores(
+ score_args=score_args,
+ scores_name=scores_name,
+ factors_name=factors_name,
+ query_dataset=eval_dataset,
+ query_indices=list(range(10)),
+ train_dataset=train_dataset,
+ per_device_query_batch_size=args.query_batch_size,
+ per_device_train_batch_size=args.train_batch_size,
+ overwrite_output_dir=False,
+ )
+ scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
+ logging.info(f"Scores shape: {scores.shape}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/dailymail/inspect_examples.py b/examples/dailymail/inspect_examples.py
new file mode 100644
index 0000000..2059b5c
--- /dev/null
+++ b/examples/dailymail/inspect_examples.py
@@ -0,0 +1,37 @@
+import logging
+
+import torch
+
+from examples.dailymail.pipeline import get_dailymail_dataset, get_tokenizer
+from kronfluence.analyzer import Analyzer
+
+
+def main():
+ logging.basicConfig(level=logging.INFO)
+
+ # You might need to change the path.
+ strategy = "ekfac"
+ scores = Analyzer.load_file(f"influence_results/dailymail/scores_{strategy}_half/pairwise_scores.safetensors")[
+ "all_modules"
+ ].to(dtype=torch.float32)
+
+ eval_idx = 1
+ train_dataset = get_dailymail_dataset(
+ split="eval_train",
+ )
+ eval_dataset = get_dailymail_dataset(
+ split="valid",
+ )
+ tokenizer = get_tokenizer()
+ print("Query Data Example:")
+ print(f"Input: {tokenizer.decode(eval_dataset[eval_idx]['input_ids'])}")
+ print(f"Label: {tokenizer.decode(eval_dataset[eval_idx]['labels'])}")
+
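+ # Index of the training example with the highest influence score for this query.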
+ top_idx = int(torch.argsort(scores[eval_idx], descending=True)[0])
+ print("Top Influential Example:")
+ print(f"Input: {tokenizer.decode(train_dataset[top_idx]['input_ids'])}")
+ print(f"Label: {tokenizer.decode(train_dataset[top_idx]['labels'])}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/dailymail/pipeline.py b/examples/dailymail/pipeline.py
new file mode 100644
index 0000000..f221cdb
--- /dev/null
+++ b/examples/dailymail/pipeline.py
@@ -0,0 +1,110 @@
+from typing import Any, List, Optional
+
+import torch.nn as nn
+from datasets import Dataset, load_dataset
+from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
+
+summarization_name_mapping = {
+ "amazon_reviews_multi": ("review_body", "review_title"),
+ "big_patent": ("description", "abstract"),
+ "cnn_dailymail": ("article", "highlights"),
+ "orange_sum": ("text", "summary"),
+ "pn_summary": ("article", "summary"),
+ "psc": ("extract_text", "summary_text"),
+ "samsum": ("dialogue", "summary"),
+ "thaisum": ("body", "summary"),
+ "xglue": ("news_body", "news_title"),
+ "xsum": ("document", "summary"),
+ "wiki_summary": ("article", "highlights"),
+ "multi_news": ("document", "summary"),
+}
+
+
+MODEL_NAME = "google-t5/t5-small"
+
+
+def construct_t5() -> nn.Module:
+ config = AutoConfig.from_pretrained(
+ MODEL_NAME,
+ trust_remote_code=True,
+ )
+ return AutoModelForSeq2SeqLM.from_pretrained(
+ MODEL_NAME,
+ from_tf=False,
+ config=config,
+ ignore_mismatched_sizes=False,
+ trust_remote_code=True,
+ )
+
+
+def get_tokenizer() -> Any:
+ return AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)
+
+
+def get_dailymail_dataset(
+ split: str,
+ indices: Optional[List[int]] = None,
+) -> Dataset:
+ raw_datasets = load_dataset("cnn_dailymail", "3.0.0")
+
+ tokenizer = get_tokenizer()
+ column_names = raw_datasets["train"].column_names
+ dataset_columns = summarization_name_mapping.get("cnn_dailymail", None)
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+
+ max_source_length = 1024
+ max_target_length = 128
+ padding = False
+ prefix = "summarize: "
+
+ def preprocess_function(examples):
+ inputs = examples[text_column]
+ targets = examples[summary_column]
+ inputs = [prefix + inp for inp in inputs]
+ model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
+
+ labels = tokenizer(
+ text_target=targets,
+ max_length=max_target_length,
+ padding=padding,
+ truncation=True,
+ )
+
+ model_inputs["labels"] = labels["input_ids"]
+ return model_inputs
+
+ if split == "train" or split == "eval_train":
+ train_dataset = raw_datasets["train"]
+ train_dataset = train_dataset.map(
+ preprocess_function,
+ batched=True,
+ num_proc=None,
+ remove_columns=column_names,
+ load_from_cache_file=True,
+ desc="Running tokenizer on dataset.",
+ )
+ ds = train_dataset
+ else:
+ valid_dataset = raw_datasets["validation"]
+ eval_dataset = valid_dataset.map(
+ preprocess_function,
+ batched=True,
+ num_proc=None,
+ remove_columns=column_names,
+ load_from_cache_file=True,
+ desc="Running tokenizer on dataset.",
+ )
+ ds = eval_dataset
+
+ if indices is not None:
+ ds = ds.select(indices)
+
+ return ds
+
+
+if __name__ == "__main__":
+ from kronfluence import Analyzer
+
+ model = construct_t5()
+ print(Analyzer.get_module_summary(model))
diff --git a/examples/dailymail/requirements.txt b/examples/dailymail/requirements.txt
new file mode 100644
index 0000000..7ae6637
--- /dev/null
+++ b/examples/dailymail/requirements.txt
@@ -0,0 +1,7 @@
+sentencepiece!=0.1.92
+nltk
+py7zr
+rouge-score
+transformers
+evaluate
+datasets
diff --git a/examples/dailymail/train.py b/examples/dailymail/train.py
new file mode 100644
index 0000000..686b3d4
--- /dev/null
+++ b/examples/dailymail/train.py
@@ -0,0 +1,228 @@
+import argparse
+import logging
+import os
+import time
+from typing import Any, Dict
+
+import evaluate
+import nltk
+import numpy as np
+import torch
+import torch.nn.functional as F
+from accelerate.utils import send_to_device, set_seed
+from filelock import FileLock
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.utils import data
+from transformers import DataCollatorForSeq2Seq
+
+from examples.dailymail.pipeline import (
+ construct_t5,
+ get_dailymail_dataset,
+ get_tokenizer,
+)
+
+try:
+ nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+ with FileLock(".lock"):
+ nltk.download("punkt", quiet=True)
+
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train seq2seq models on CNN/DailyMail dataset.")
+
+ parser.add_argument(
+ "--train_batch_size",
+ type=int,
+ default=16,
+ help="Batch size for the training dataloader.",
+ )
+ parser.add_argument(
+ "--eval_batch_size",
+ type=int,
+ default=32,
+ help="Batch size for the evaluation dataloader.",
+ )
+
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-05,
+ help="Fixed learning rate to train the model.",
+ )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.01,
+ help="Weight decay to train the model.",
+ )
+ parser.add_argument(
+ "--num_train_epochs",
+ type=int,
+ default=3,
+ help="Total number of epochs to train the model.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1004,
+ help="A seed for reproducible training pipeline.",
+ )
+ parser.add_argument(
+ "--checkpoint_dir",
+ type=str,
+ default="./checkpoints",
+ help="A path to store the final checkpoint.",
+ )
+ args = parser.parse_args()
+
+ if args.checkpoint_dir is not None:
+ os.makedirs(args.checkpoint_dir, exist_ok=True)
+
+ return args
+
+
+def train(
+ dataset: data.Dataset,
+ tokenizer: Any,
+ batch_size: int,
+ num_train_epochs: int,
+ learning_rate: float,
+ weight_decay: float,
+) -> nn.Module:
+ model = construct_t5().to(DEVICE)
+ data_collator = DataCollatorForSeq2Seq(
+ tokenizer,
+ model=model,
+ label_pad_token_id=-100,
+ pad_to_multiple_of=None,
+ )
+ train_dataloader = data.DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ )
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+
+ start_time = time.time()
+ model.train()
+ for epoch in range(num_train_epochs):
+ total_loss = 0.0
+ for batch in train_dataloader:
+ optimizer.zero_grad(set_to_none=True)
+ batch = send_to_device(batch, device=DEVICE)
+ loss = model(**batch).loss
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.detach().float()
+ # Average over batches; each batch's loss is already a mean over tokens.
+ logging.info(f"Epoch {epoch + 1} - Averaged Loss: {total_loss / len(train_dataloader)}")
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ logging.info(f"Completed training in {elapsed_time:.2f} seconds.")
+ return model
+
+
+def evaluate_model(model: nn.Module, tokenizer: Any, dataset: data.Dataset, batch_size: int) -> Dict[str, Any]:
+ data_collator = DataCollatorForSeq2Seq(
+ tokenizer,
+ model=model,
+ label_pad_token_id=-100,
+ pad_to_multiple_of=None,
+ )
+ dataloader = data.DataLoader(
+ dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=data_collator
+ )
+ model.eval()
+
+ def postprocess_text(preds, labels):
+ preds = [pred.strip() for pred in preds]
+ labels = [label.strip() for label in labels]
+
+ # rougeLSum expects newline after each sentence.
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+ return preds, labels
+
+ gen_kwargs = {
+ "max_length": 128,
+ }
+ metric = evaluate.load("rouge")
+ loss_fn = CrossEntropyLoss(ignore_index=-100, reduction="mean")
+ total_loss = 0.0
+ for step, batch in enumerate(dataloader):
+ with torch.no_grad():
+ logits = model(
+ input_ids=batch["input_ids"].to(device=DEVICE),
+ attention_mask=batch["attention_mask"].to(device=DEVICE),
+ decoder_input_ids=batch["decoder_input_ids"].to(device=DEVICE),
+ ).logits
+ labels = batch["labels"].to(device=DEVICE)
+ loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
+ total_loss += loss.detach().float().item()
+
+ labels = labels.cpu().numpy()
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+ generated_tokens = model.generate(
+ batch["input_ids"].to(device=DEVICE),
+ attention_mask=batch["attention_mask"].to(device=DEVICE),
+ **gen_kwargs,
+ )
+ if isinstance(generated_tokens, tuple):
+ generated_tokens = generated_tokens[0]
+ decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+ metric.add_batch(
+ predictions=decoded_preds,
+ references=decoded_labels,
+ )
+
+ result = metric.compute(use_stemmer=True)
+ result = {k: round(v * 100, 4) for k, v in result.items()}
+ result["loss"] = total_loss / len(dataloader)
+ return result
+
+
+def main():
+ args = parse_args()
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ tokenizer = get_tokenizer()
+ train_dataset = get_dailymail_dataset(split="train")
+ model = train(
+ dataset=train_dataset,
+ tokenizer=tokenizer,
+ batch_size=args.train_batch_size,
+ num_train_epochs=args.num_train_epochs,
+ learning_rate=args.learning_rate,
+ weight_decay=args.weight_decay,
+ )
+
+ eval_train_dataset = get_dailymail_dataset(split="eval_train")
+ results = evaluate_model(
+ model=model, tokenizer=tokenizer, dataset=eval_train_dataset, batch_size=args.eval_batch_size
+ )
+ logger.info(f"Train evaluation results: {results}")
+
+ eval_dataset = get_dailymail_dataset(split="valid")
+ results = evaluate_model(model=model, tokenizer=tokenizer, dataset=eval_dataset, batch_size=args.eval_batch_size)
+ logger.info(f"Valid evaluation results: {results}")
+
+ if args.checkpoint_dir is not None:
+ torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/glue/README.md b/examples/glue/README.md
index 045eac4..c1b43a4 100644
--- a/examples/glue/README.md
+++ b/examples/glue/README.md
@@ -9,7 +9,7 @@ pip install -r requirements.txt
## Training
-To fine-tune BERT on a specific dataset, run the following command (we are using the `SST2` dataset in this example):
+To fine-tune BERT on a specific dataset, run the following command (we are using the SST2 dataset in this example):
```bash
python train.py --dataset_name sst2 \
@@ -36,25 +36,25 @@ python analyze.py --dataset_name sst2 \
--factor_strategy ekfac
```
-On an A100 (80GB), it takes roughly 95 minutes to compute the pairwise scores for `SST2` (including computing EKFAC factors):
+On an A100 (80GB), it takes roughly 90 minutes to compute the pairwise scores for SST2 (including computing EKFAC factors):
```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 5568.5 | 100 % |
+| Total | - | 11 | 5088.0 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 2668.0 | 1 | 2668.0 | 47.913 |
-| Fit Lambda | 2361.5 | 1 | 2361.5 | 42.408 |
-| Fit Covariance | 483.63 | 1 | 483.63 | 8.685 |
-| Perform Eigendecomposition | 26.307 | 1 | 26.307 | 0.47243 |
-| Save Covariance | 11.445 | 1 | 11.445 | 0.20552 |
-| Save Eigendecomposition | 10.959 | 1 | 10.959 | 0.1968 |
-| Save Lambda | 3.0458 | 1 | 3.0458 | 0.054696 |
-| Save Pairwise Score | 2.0978 | 1 | 2.0978 | 0.037671 |
-| Load Covariance | 0.72168 | 1 | 0.72168 | 0.01296 |
-| Load Eigendecomposition | 0.5194 | 1 | 0.5194 | 0.0093274 |
-| Load All Factors | 0.25427 | 1 | 0.25427 | 0.0045661 |
+| Fit Lambda | 2370.0 | 1 | 2370.0 | 46.581 |
+| Compute Pairwise Score | 2222.4 | 1 | 2222.4 | 43.679 |
+| Fit Covariance | 478.83 | 1 | 478.83 | 9.411 |
+| Perform Eigendecomposition | 10.587 | 1 | 10.587 | 0.20808 |
+| Save Eigendecomposition | 2.5419 | 1 | 2.5419 | 0.049958 |
+| Save Covariance | 2.3878 | 1 | 2.3878 | 0.046931 |
+| Save Lambda | 0.66905 | 1 | 0.66905 | 0.01315 |
+| Save Pairwise Score | 0.51374 | 1 | 0.51374 | 0.010097 |
+| Load All Factors | 0.01321 | 1 | 0.01321 | 0.00025963 |
+| Load Covariance | 0.0081149 | 1 | 0.0081149 | 0.00015949 |
+| Load Eigendecomposition | 0.0079874 | 1 | 0.0079874 | 0.00015699 |
----------------------------------------------------------------------------------------------------------------------------------
```
@@ -69,32 +69,32 @@ python analyze.py --dataset_name sst2 \
--use_half_precision
```
-This reduces computation time to about 30 minutes on an A100 (80GB) GPU.
+This reduces computation time to about 20 minutes on an A100 (80GB) GPU.
```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 1832.0 | 100 % |
+| Total | - | 11 | 1222.4 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 1143.2 | 1 | 1143.2 | 62.4 |
-| Fit Lambda | 555.97 | 1 | 555.97 | 30.348 |
-| Fit Covariance | 99.467 | 1 | 99.467 | 5.4294 |
-| Perform Eigendecomposition | 18.566 | 1 | 18.566 | 1.0134 |
-| Save Covariance | 5.4877 | 1 | 5.4877 | 0.29954 |
-| Save Eigendecomposition | 5.3713 | 1 | 5.3713 | 0.29319 |
-| Save Lambda | 1.5586 | 1 | 1.5586 | 0.085078 |
-| Save Pairwise Score | 1.0651 | 1 | 1.0651 | 0.05814 |
-| Load Eigendecomposition | 0.54052 | 1 | 0.54052 | 0.029504 |
-| Load Covariance | 0.53759 | 1 | 0.53759 | 0.029345 |
-| Load All Factors | 0.26048 | 1 | 0.26048 | 0.014218 |
+| Compute Pairwise Score | 582.08 | 1 | 582.08 | 47.617 |
+| Fit Lambda | 543.55 | 1 | 543.55 | 44.465 |
+| Fit Covariance | 83.877 | 1 | 83.877 | 6.8616 |
+| Perform Eigendecomposition | 9.4054 | 1 | 9.4054 | 0.76942 |
+| Save Eigendecomposition | 1.516 | 1 | 1.516 | 0.12401 |
+| Save Covariance | 1.434 | 1 | 1.434 | 0.11731 |
+| Save Lambda | 0.28022 | 1 | 0.28022 | 0.022924 |
+| Save Pairwise Score | 0.24123 | 1 | 0.24123 | 0.019734 |
+| Load All Factors | 0.01241 | 1 | 0.01241 | 0.0010152 |
+| Load Covariance | 0.0080553 | 1 | 0.0080553 | 0.00065897 |
+| Load Eigendecomposition | 0.0077278 | 1 | 0.0077278 | 0.00063218 |
----------------------------------------------------------------------------------------------------------------------------------
```
## Counterfactual Evaluation
-Evaluate the impact of removing top positively influential training examples on query misclassification.
-First, compute pairwise influence scores for the `RTE` dataset:
+Let's evaluate the impact of removing top positively influential training examples on query misclassification.
+First, compute pairwise influence scores for the `RTE` dataset (the commands below were run on a single A100 GPU):
```bash
python train.py --dataset_name rte \
diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py
index 6876984..90dd71f 100644
--- a/examples/glue/analyze.py
+++ b/examples/glue/analyze.py
@@ -1,7 +1,7 @@
import argparse
import logging
import os
-from typing import Dict, Optional
+from typing import Dict
import torch
import torch.nn.functional as F
@@ -126,7 +126,7 @@ def compute_measurement(
margins = logits_correct - cloned_logits.logsumexp(dim=-1)
return -margins.sum()
- def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
+ def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
return batch["attention_mask"]
@@ -188,8 +188,8 @@ def main():
scores_name += "_half"
rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
if rank is not None:
- score_args.query_gradient_rank = rank
- score_args.num_query_gradient_accumulations = 10
+ score_args.query_gradient_low_rank = rank
+ score_args.query_gradient_accumulation_steps = 10
scores_name += f"_qlr{rank}"
analyzer.compute_pairwise_scores(
score_args=score_args,
diff --git a/examples/glue/half_precision_analysis.py b/examples/glue/half_precision_analysis.py
new file mode 100644
index 0000000..288c39c
--- /dev/null
+++ b/examples/glue/half_precision_analysis.py
@@ -0,0 +1,40 @@
+import logging
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import spearmanr
+from tueplots import markers
+
+from kronfluence.analyzer import Analyzer
+
+
+def main():
+ logging.basicConfig(level=logging.INFO)
+
+ # Load the scores. You might need to modify the path.
+ scores = Analyzer.load_file("influence_results/sst2/scores_ekfac/pairwise_scores.safetensors")["all_modules"]
+ half_scores = Analyzer.load_file("influence_results/sst2/scores_ekfac_half/pairwise_scores.safetensors")[
+ "all_modules"
+ ].float()
+
+ plt.rcParams.update({"figure.dpi": 150})
+ plt.rcParams.update(markers.with_edge())
+ plt.rcParams["axes.axisbelow"] = True
+
+ # Only plot first 6000 points to avoid clutter.
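+    # Use query index 79 as an illustrative example.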
+ idx = 79
+ plt.scatter(half_scores[idx][:6000], scores[idx][:6000], edgecolor="k")
+ plt.grid()
+ plt.xlabel("bfloat16")
+ plt.ylabel("float32")
+ plt.show()
+
+    # Compute the averaged Spearman correlation.
+ all_corr = []
+ for i in range(500):
+ all_corr.append(spearmanr(scores[i], half_scores[i])[0])
+ logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md
index e959c19..6595bb5 100644
--- a/examples/imagenet/README.md
+++ b/examples/imagenet/README.md
@@ -17,7 +17,7 @@ To compute pairwise influence scores on 1000 query data points using the `ekfac`
python analyze.py --dataset_dir PATH_TO_IMAGENET \
--query_gradient_rank -1 \
--query_batch_size 100 \
- --train_batch_size 300 \
+ --train_batch_size 256 \
--factor_strategy ekfac
```
@@ -53,17 +53,17 @@ python analyze.py --dataset_dir PATH_TO_IMAGENET \
--factor_strategy ekfac
```
-On an A100 (80GB) GPU, it takes roughly 3.5 hours to compute the pairwise scores with query batching (including computing EKFAC factors):
+On an A100 (80GB) GPU, it takes roughly 3 hours to compute the pairwise scores with query batching:
```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 3 | 9896.8 | 100 % |
+| Total | - | 3 | 7352.8 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 9849.7 | 1 | 9849.7 | 99.524 |
-| Save Pairwise Score | 47.075 | 1 | 47.075 | 0.47566 |
-| Load All Factors | 0.014463 | 1 | 0.014463 | 0.00014614 |
+| Compute Pairwise Score | 7340.8 | 1 | 7340.8 | 99.836 |
+| Save Pairwise Score | 12.026 | 1 | 12.026 | 0.16355 |
+| Load All Factors | 0.0099941 | 1 | 0.0099941 | 0.00013592 |
----------------------------------------------------------------------------------------------------------------------------------
```
@@ -91,19 +91,19 @@ This reduces computation time to about 85 minutes on an A100 (80GB) GPU:
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 4926.3 | 100 % |
-----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 4446.0 | 1 | 4446.0 | 90.249 |
-| Fit Lambda | 255.45 | 1 | 255.45 | 5.1853 |
-| Fit Covariance | 186.86 | 1 | 186.86 | 3.7931 |
-| Save Pairwise Score | 23.205 | 1 | 23.205 | 0.47104 |
-| Perform Eigendecomposition | 7.1356 | 1 | 7.1356 | 0.14485 |
-| Save Eigendecomposition | 3.3045 | 1 | 3.3045 | 0.067079 |
-| Save Covariance | 2.993 | 1 | 2.993 | 0.060756 |
-| Save Lambda | 0.58278 | 1 | 0.58278 | 0.01183 |
-| Load Eigendecomposition | 0.39114 | 1 | 0.39114 | 0.0079398 |
-| Load Covariance | 0.27701 | 1 | 0.27701 | 0.005623 |
-| Load All Factors | 0.1699 | 1 | 0.1699 | 0.0034489 |
+| Total | - | 11 | 3023.7 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 2621.2 | 1 | 2621.2 | 86.688 |
+| Fit Lambda | 232.3 | 1 | 232.3 | 7.6825 |
+| Fit Covariance | 157.36 | 1 | 157.36 | 5.204 |
+| Perform Eigendecomposition | 5.745 | 1 | 5.745 | 0.19 |
+| Save Pairwise Score | 5.6676 | 1 | 5.6676 | 0.18744 |
+| Save Covariance | 0.70454 | 1 | 0.70454 | 0.0233 |
+| Save Eigendecomposition | 0.61539 | 1 | 0.61539 | 0.020352 |
+| Save Lambda | 0.092784 | 1 | 0.092784 | 0.0030685 |
+| Load Covariance | 0.013714 | 1 | 0.013714 | 0.00045354 |
+| Load All Factors | 0.0088742 | 1 | 0.0088742 | 0.00029348 |
+| Load Eigendecomposition | 0.0056237 | 1 | 0.0056237 | 0.00018599 |
----------------------------------------------------------------------------------------------------------------------------------
```
diff --git a/examples/imagenet/analyze.py b/examples/imagenet/analyze.py
index d77de2e..9dd14b1 100644
--- a/examples/imagenet/analyze.py
+++ b/examples/imagenet/analyze.py
@@ -101,6 +101,7 @@ def main():
per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=False,
+ initial_per_device_batch_size_attempt=512,
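+        # With per_device_batch_size=None, the automatic batch size search starts from this value.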
)
# Compute pairwise scores.
@@ -111,8 +112,8 @@ def main():
scores_name += "_half"
rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
if rank is not None:
- score_args.query_gradient_rank = rank
- score_args.num_query_gradient_accumulations = 10
+ score_args.query_gradient_low_rank = rank
+ score_args.query_gradient_accumulation_steps = 10
scores_name += f"_qlr{rank}"
analyzer.compute_pairwise_scores(
scores_name=scores_name,
diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py
index f3c9733..33f9653 100644
--- a/examples/imagenet/ddp_analyze.py
+++ b/examples/imagenet/ddp_analyze.py
@@ -1,7 +1,6 @@
import argparse
import logging
import os
-from typing import Tuple
import torch
@@ -133,8 +132,8 @@ def main():
scores_name += "_half"
rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
if rank is not None:
- score_args.query_gradient_rank = rank
- score_args.num_query_gradient_accumulations = 10
+ score_args.query_gradient_low_rank = rank
+ score_args.query_gradient_accumulation_steps = 10
scores_name += f"_qlr{rank}"
analyzer.compute_pairwise_scores(
scores_name=scores_name,
diff --git a/examples/imagenet/query_batching_analysis.py b/examples/imagenet/query_batching_analysis.py
index ae3e92b..6e82e98 100644
--- a/examples/imagenet/query_batching_analysis.py
+++ b/examples/imagenet/query_batching_analysis.py
@@ -2,7 +2,7 @@
import matplotlib.pyplot as plt
import numpy as np
-from scipy.stats import spearmanr
+from scipy.stats import pearsonr, spearmanr
from tueplots import markers
from kronfluence.analyzer import Analyzer
@@ -18,12 +18,15 @@ def main():
lr_scores = Analyzer.load_file("influence_results/imagenet/scores_ekfac_qlr32/pairwise_scores.safetensors")[
"all_modules"
]
+ # lr_scores = Analyzer.load_file("influence_results/imagenet/scores_ekfac_half_qlr32/pairwise_scores.safetensors")[
+ # "all_modules"
+ # ].float()
- # Only plot first 1000 points to avoid clutter.
+ # Only plot first 5000 points to avoid clutter.
plt.rcParams.update({"figure.dpi": 150})
plt.rcParams.update(markers.with_edge())
plt.rcParams["axes.axisbelow"] = True
- plt.scatter(lr_scores[0][:1000], full_scores[0][:1000], edgecolor="k")
+ plt.scatter(lr_scores[0][:5000], full_scores[0][:5000], edgecolor="k")
plt.grid()
plt.xlabel("Full Rank Score")
plt.ylabel("Low Rank (32) Score")
@@ -35,6 +38,12 @@ def main():
all_corr.append(spearmanr(full_scores[i], lr_scores[i])[0])
logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}")
+    # Compute the averaged Pearson correlation.
+    all_corr = []
+    for i in range(100):
+        all_corr.append(pearsonr(full_scores[i], lr_scores[i])[0])
+    logging.info(f"Averaged Pearson Correlation: {np.array(all_corr).mean()}")
+
if __name__ == "__main__":
main()
diff --git a/examples/openwebtext/README.md b/examples/openwebtext/README.md
index aeca298..f10ece0 100644
--- a/examples/openwebtext/README.md
+++ b/examples/openwebtext/README.md
@@ -1,6 +1,29 @@
+# OpenWebText & Llama-3-8B Example
+
+This directory contains scripts for computing influence scores on a subset of the OpenWebText dataset. The pipeline is inspired by the [LogIX repository](https://github.com/logix-project/logix/tree/main/examples/language_modeling).
+Install the necessary packages:
+
```bash
-python analyze.py --query_batch_size 32 \
- --train_batch_size 64 \
- --checkpoint_dir ./checkpoints \
- --factor_strategy ekfac
+pip install -r requirements.txt
+```
+
+We will use the pre-trained Meta-Llama-3-8B model [from HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B).
+
+## Computing EKFAC Factors
+
+To compute factors using the `ekfac` factorization strategy, run the following command, which uses 4 A100 (80GB) GPUs:
+
+```bash
+torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factors_name jul_11_2024 --factor_batch_size 4
+```
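+
+Internally, `fit_factors.py` uses memory-saving factor arguments. A minimal sketch of the equivalent configuration (the partition counts mirror the script; treat them as values to tune for your hardware):
+
+```python
+import torch
+
+from kronfluence.utils.common.factor_arguments import extreme_reduce_memory_factor_arguments
+
+# EKFAC factors in bfloat16 with memory-saving defaults.
+factor_args = extreme_reduce_memory_factor_arguments(strategy="ekfac", module_partitions=1, dtype=torch.bfloat16)
+# Split the covariance and Lambda computations across module partitions to reduce peak memory.
+factor_args.covariance_module_partitions = 2
+factor_args.lambda_module_partitions = 4
+```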
+
+## Computing Influence Scores
+
+The `generate.py` script generates responses from the Llama-3-8B model given a prompt.
+Some prompt and completion pairs are saved in `data/data.json`.
+
+To compute influence scores on the generated prompt and completion pairs, run the following command:
+
+```bash
+torchrun --standalone --nnodes=1 --nproc-per-node=4 compute_scores.py --train_batch_size 8 --query_gradient_rank 32
```
\ No newline at end of file
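+
+Once the run finishes, the saved pairwise scores can be inspected with `Analyzer.load_file`. A minimal sketch (the path assumes the default `analysis_name="openwebtext"` and `scores_name="ekfac"` used by `compute_scores.py`):
+
+```python
+import torch
+
+from kronfluence.analyzer import Analyzer
+
+# Load the pairwise influence scores saved by compute_scores.py.
+scores = Analyzer.load_file("influence_results/openwebtext/scores_ekfac/pairwise_scores.safetensors")["all_modules"]
+print(scores.shape)  # (num_query_examples, num_train_examples)
+
+# Top 5 most positively influential training examples for the first query.
+top_scores, top_indices = torch.topk(scores[0].float(), k=5)
+```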
diff --git a/examples/openwebtext/analyze.py b/examples/openwebtext/analyze.py
deleted file mode 100644
index 230518c..0000000
--- a/examples/openwebtext/analyze.py
+++ /dev/null
@@ -1,217 +0,0 @@
-import argparse
-import logging
-import os
-from typing import Dict, List, Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from transformers import default_data_collator
-
-from examples.openwebtext.pipeline import (
- construct_llama3,
- get_custom_dataset,
- get_openwebtext_dataset,
-)
-from examples.wikitext.pipeline import construct_gpt2, get_wikitext_dataset
-from kronfluence.analyzer import Analyzer, prepare_model
-from kronfluence.arguments import FactorArguments, ScoreArguments
-from kronfluence.task import Task
-from kronfluence.utils.common.factor_arguments import (
- all_low_precision_factor_arguments,
- extreme_reduce_memory_factor_arguments,
-)
-from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
-from kronfluence.utils.dataset import DataLoaderKwargs
-
-BATCH_TYPE = Dict[str, torch.Tensor]
-
-
-if torch.cuda.is_available():
- torch.backends.cuda.matmul.allow_tf32 = True
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Influence analysis on WikiText dataset.")
-
- parser.add_argument(
- "--checkpoint_dir",
- type=str,
- default="./checkpoints",
- help="A path that is storing the final checkpoint of the model.",
- )
-
- parser.add_argument(
- "--factor_strategy",
- type=str,
- default="ekfac",
- help="Strategy to compute influence factors.",
- )
- parser.add_argument(
- "--query_gradient_rank",
- type=int,
- default=-1,
- help="Rank for the low-rank query gradient approximation.",
- )
- parser.add_argument(
- "--use_half_precision",
- action="store_true",
- default=False,
- help="Whether to use half precision for computing factors and scores.",
- )
- parser.add_argument(
- "--use_compile",
- action="store_true",
- default=False,
- help="Whether to use torch compile for computing factors and scores.",
- )
- parser.add_argument(
- "--query_batch_size",
- type=int,
- default=1,
- help="Batch size for computing query gradients.",
- )
- parser.add_argument(
- "--train_batch_size",
- type=int,
- default=8,
- help="Batch size for computing query gradients.",
- )
- parser.add_argument(
- "--profile",
- action="store_true",
- default=False,
- help="Boolean flag to profile computations.",
- )
- args = parser.parse_args()
-
- if args.checkpoint_dir is not None:
- os.makedirs(args.checkpoint_dir, exist_ok=True)
-
- return args
-
-
-class LanguageModelingTask(Task):
- def compute_train_loss(
- self,
- batch: BATCH_TYPE,
- model: nn.Module,
- sample: bool = False,
- ) -> torch.Tensor:
- logits = model(
- input_ids=batch["input_ids"],
- attention_mask=batch["attention_mask"],
- ).logits
- shift_logits = logits[..., :-1, :].contiguous()
-
- if not sample:
- labels = batch["labels"]
- shift_labels = labels[..., 1:].contiguous()
- reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
- summed_loss = F.cross_entropy(
- reshaped_shift_logits, shift_labels.view(-1), reduction="sum", ignore_index=-100
- )
- else:
- reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
- with torch.no_grad():
- probs = torch.nn.functional.softmax(reshaped_shift_logits.detach(), dim=-1)
- sampled_labels = torch.multinomial(
- probs,
- num_samples=1,
- ).flatten()
- summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, reduction="sum")
- return summed_loss
-
- def compute_measurement(
- self,
- batch: BATCH_TYPE,
- model: nn.Module,
- ) -> torch.Tensor:
- logits = model(
- input_ids=batch["input_ids"],
- attention_mask=batch["attention_mask"],
- ).logits
- shift_logits = logits[..., :-1, :].contiguous()
- labels = batch["labels"]
- shift_labels = labels[..., 1:].contiguous().view(-1)
- reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
- return F.cross_entropy(reshaped_shift_logits, shift_labels, reduction="sum", ignore_index=-100)
-
- def tracked_modules(self) -> List[str]:
- total_modules = []
-
- for i in range(32):
- # Only uses the MLP modules.
- total_modules.append(f"model.layers.{i}.mlp.gate_proj")
- total_modules.append(f"model.layers.{i}.mlp.up_proj")
- total_modules.append(f"model.layers.{i}.mlp.down_proj")
-
- return total_modules
-
- def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
- return batch["attention_mask"]
-
-
-def main():
- args = parse_args()
- logging.basicConfig(level=logging.INFO)
-
- # Prepare the dataset.
- train_dataset = get_openwebtext_dataset()
- eval_dataset = get_custom_dataset()
-
- # Prepare the trained model.
- model = construct_llama3().to(dtype=torch.bfloat16)
-
- # Define task and prepare model.
- task = LanguageModelingTask()
- model = prepare_model(model, task)
-
- analyzer = Analyzer(
- analysis_name="openwebtext",
- model=model,
- task=task,
- profile=args.profile,
- )
- # Configure parameters for DataLoader.
- dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator)
- analyzer.set_dataloader_kwargs(dataloader_kwargs)
-
- # Compute influence factors.
- factors_name = args.factor_strategy
- factor_args = extreme_reduce_memory_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16)
- analyzer.fit_all_factors(
- factors_name=factors_name,
- dataset=train_dataset,
- per_device_batch_size=None,
- factor_args=factor_args,
- overwrite_output_dir=False,
- initial_per_device_batch_size_attempt=64,
- )
-
- # Compute pairwise scores.
- scores_name = factor_args.strategy
- score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
-
- rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
- score_args.num_query_gradient_accumulations = 10
- if rank is not None:
- score_args.query_gradient_rank = rank
- scores_name += f"_qlr{rank}"
- analyzer.compute_pairwise_scores(
- scores_name=scores_name,
- score_args=score_args,
- factors_name=factors_name,
- query_dataset=eval_dataset,
- query_indices=list(range(min([len(eval_dataset), 2000]))),
- train_dataset=train_dataset,
- per_device_query_batch_size=args.query_batch_size,
- per_device_train_batch_size=args.train_batch_size,
- overwrite_output_dir=False,
- )
- scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
- logging.info(f"Scores shape: {scores.shape}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/openwebtext/compute_scores.py b/examples/openwebtext/compute_scores.py
new file mode 100644
index 0000000..61622d3
--- /dev/null
+++ b/examples/openwebtext/compute_scores.py
@@ -0,0 +1,116 @@
+import argparse
+import logging
+from datetime import timedelta
+from typing import Dict
+
+import torch
+from accelerate import Accelerator, InitProcessGroupKwargs
+from transformers import default_data_collator
+
+from examples.openwebtext.pipeline import (
+    construct_llama3,
+    get_custom_dataset,
+    get_openwebtext_dataset,
+)
+from examples.openwebtext.task import LanguageModelingTask
+from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.utils.common.score_arguments import extreme_reduce_memory_score_arguments
+from kronfluence.utils.dataset import DataLoaderKwargs
+
+BATCH_TYPE = Dict[str, torch.Tensor]
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Influence score computation on Openwebtext dataset.")
+
+ parser.add_argument(
+ "--factor_strategy",
+ type=str,
+ default="ekfac",
+ help="Strategy to compute influence factors.",
+ )
+ parser.add_argument(
+ "--query_gradient_rank",
+ type=int,
+ default=-1,
+ help="Rank for the low-rank query gradient approximation.",
+ )
+ parser.add_argument(
+ "--train_batch_size",
+ type=int,
+ default=4,
+ help="Batch size for computing query gradients.",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ default=False,
+ help="Boolean flag to profile computations.",
+ )
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging.basicConfig(level=logging.INFO)
+
+ # Prepare the dataset.
+ train_dataset = get_openwebtext_dataset()
+ eval_dataset = get_custom_dataset()
+
+ # Prepare the trained model.
+ model = construct_llama3()
+
+ # Define task and prepare model.
+ task = LanguageModelingTask()
+ model = prepare_model(model, task)
+
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400)) # 1.5 hours.
+ accelerator = Accelerator(kwargs_handlers=[kwargs])
+ model = accelerator.prepare_model(model)
+
+ analyzer = Analyzer(
+ analysis_name="openwebtext",
+ model=model,
+ task=task,
+ profile=args.profile,
+ )
+ # Configure parameters for DataLoader.
+ dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
+ analyzer.set_dataloader_kwargs(dataloader_kwargs)
+
+ scores_name = args.factor_strategy
+ rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
+ score_args = extreme_reduce_memory_score_arguments(
+ damping_factor=None, module_partitions=1, query_gradient_low_rank=rank, dtype=torch.bfloat16
+ )
+ # score_args.module_partitions = 2
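+    # Uncomment the line above to split the score computation across module partitions when memory is tight.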
+ score_args.query_gradient_accumulation_steps = 10
+ analyzer.compute_pairwise_scores(
+ scores_name=scores_name,
+ score_args=score_args,
+ factors_name=args.factor_strategy,
+ query_dataset=eval_dataset,
+ train_dataset=train_dataset,
+ per_device_query_batch_size=1,
+ per_device_train_batch_size=args.train_batch_size,
+ overwrite_output_dir=False,
+ )
+ scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
+ logging.info(f"Scores shape: {scores.shape}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/openwebtext/fit_factors.py b/examples/openwebtext/fit_factors.py
new file mode 100644
index 0000000..740f349
--- /dev/null
+++ b/examples/openwebtext/fit_factors.py
@@ -0,0 +1,102 @@
+import argparse
+import logging
+from datetime import timedelta
+from typing import Dict
+
+import torch
+from accelerate import Accelerator, InitProcessGroupKwargs
+from transformers import default_data_collator
+
+from examples.openwebtext.pipeline import construct_llama3, get_openwebtext_dataset
+from examples.openwebtext.task import LanguageModelingTask
+from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.utils.common.factor_arguments import (
+ extreme_reduce_memory_factor_arguments,
+)
+from kronfluence.utils.dataset import DataLoaderKwargs
+
+BATCH_TYPE = Dict[str, torch.Tensor]
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Influence factor computation on Openwebtext dataset.")
+
+ parser.add_argument(
+ "--factors_name",
+ type=str,
+ default="july_11",
+ help="Strategy to compute influence factors.",
+ )
+ parser.add_argument(
+ "--factor_strategy",
+ type=str,
+ default="ekfac",
+ help="Strategy to compute influence factors.",
+ )
+ parser.add_argument(
+ "--factor_batch_size",
+ type=int,
+ default=4,
+ help="Batch size for computing influence factors.",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ default=False,
+ help="Boolean flag to profile computations.",
+ )
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging.basicConfig(level=logging.INFO)
+
+ # Prepare the dataset.
+ train_dataset = get_openwebtext_dataset()
+
+ # Prepare the trained model.
+ model = construct_llama3()
+
+ # Define task and prepare model.
+ task = LanguageModelingTask()
+ model = prepare_model(model, task)
+
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400)) # 1.5 hours.
+ accelerator = Accelerator(kwargs_handlers=[kwargs])
+ model = accelerator.prepare_model(model)
+
+ analyzer = Analyzer(
+ analysis_name="openwebtext",
+ model=model,
+ task=task,
+ profile=args.profile,
+ )
+ # Configure parameters for DataLoader.
+ dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
+ analyzer.set_dataloader_kwargs(dataloader_kwargs)
+
+ factors_name = args.factors_name
+ factor_args = extreme_reduce_memory_factor_arguments(
+ strategy=args.factor_strategy, module_partitions=1, dtype=torch.bfloat16
+ )
+ factor_args.covariance_module_partitions = 2
+ factor_args.lambda_module_partitions = 4
+ analyzer.fit_all_factors(
+ factors_name=factors_name,
+ dataset=train_dataset,
+ per_device_batch_size=args.factor_batch_size,
+ factor_args=factor_args,
+ overwrite_output_dir=False,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/openwebtext/pipeline.py b/examples/openwebtext/pipeline.py
index 1c4589a..fda0fe0 100644
--- a/examples/openwebtext/pipeline.py
+++ b/examples/openwebtext/pipeline.py
@@ -1,19 +1,22 @@
import copy
from typing import List
+import torch
from datasets import load_dataset
from torch import nn
from torch.utils import data
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
+MAX_LENGTH = 512
+
def construct_llama3() -> nn.Module:
config = AutoConfig.from_pretrained(
- "meta-llama/Meta-Llama-3-8B",
- trust_remote_code=True,
+ MODEL_NAME, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)
model = AutoModelForCausalLM.from_pretrained(
- "meta-llama/Meta-Llama-3-8B",
+ MODEL_NAME,
from_tf=False,
config=config,
ignore_mismatched_sizes=False,
@@ -25,17 +28,19 @@ def construct_llama3() -> nn.Module:
def get_openwebtext_dataset(
indices: List[int] = None,
) -> data.Dataset:
- raw_datasets = load_dataset("stas/openwebtext-10k")
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True, trust_remote_code=True)
-
+ raw_datasets = load_dataset("Elriggs/openwebtext-100k")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
def tokenize_function(examples):
- results = tokenizer(examples[text_column_name], truncation=True, padding=True, max_length=64)
+ results = tokenizer(examples[text_column_name], truncation=True, padding=True, max_length=MAX_LENGTH)
results["labels"] = results["input_ids"].copy()
+ results["labels"] = [
+ [-100 if token == tokenizer.pad_token_id else token for token in label] for label in results["labels"]
+ ]
return results
tokenized_datasets = raw_datasets.map(
@@ -47,8 +52,7 @@ def tokenize_function(examples):
desc="Running tokenizer on dataset",
)
- train_dataset = tokenized_datasets["train"]
- ds = train_dataset
+ ds = tokenized_datasets["train"]
if indices is not None:
ds = ds.select(indices)
@@ -65,7 +69,7 @@ def get_custom_dataset(
"num_proc": 4,
}
raw_datasets = load_dataset(**data_kwargs)["train"]
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)
def tokenize_function(examples):
data_dict = {}
diff --git a/examples/openwebtext/requirements.txt b/examples/openwebtext/requirements.txt
new file mode 100644
index 0000000..462542f
--- /dev/null
+++ b/examples/openwebtext/requirements.txt
@@ -0,0 +1,2 @@
+transformers
+datasets
diff --git a/examples/openwebtext/task.py b/examples/openwebtext/task.py
new file mode 100644
index 0000000..343dba2
--- /dev/null
+++ b/examples/openwebtext/task.py
@@ -0,0 +1,71 @@
+from typing import Dict, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from kronfluence.task import Task
+
+BATCH_TYPE = Dict[str, torch.Tensor]
+
+
+class LanguageModelingTask(Task):
+ def compute_train_loss(
+ self,
+ batch: BATCH_TYPE,
+ model: nn.Module,
+ sample: bool = False,
+ ) -> torch.Tensor:
+ logits = model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ ).logits
+ logits = logits[..., :-1, :].contiguous()
+ logits = logits.view(-1, logits.size(-1))
+
+ if not sample:
+ labels = batch["labels"]
+ shift_labels = labels[..., 1:].contiguous()
+ summed_loss = F.cross_entropy(logits, shift_labels.view(-1), reduction="sum", ignore_index=-100)
+ else:
+ with torch.no_grad():
+ probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
+ sampled_labels = torch.multinomial(
+ probs,
+ num_samples=1,
+ ).flatten()
+ summed_loss = F.cross_entropy(logits, sampled_labels, ignore_index=-100, reduction="sum")
+ return summed_loss
+
+ def compute_measurement(
+ self,
+ batch: BATCH_TYPE,
+ model: nn.Module,
+ ) -> torch.Tensor:
+ logits = model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ ).logits
+ shift_labels = batch["labels"][..., 1:].contiguous().view(-1)
+ logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
+ return F.cross_entropy(logits, shift_labels, ignore_index=-100, reduction="sum")
+
+ def get_influence_tracked_modules(self) -> List[str]:
+ total_modules = []
+
+        # Uncomment the following lines to also compute influence scores on the attention layers.
+ # for i in range(32):
+ # total_modules.append(f"model.layers.{i}.self_attn.q_proj")
+ # total_modules.append(f"model.layers.{i}.self_attn.k_proj")
+ # total_modules.append(f"model.layers.{i}.self_attn.v_proj")
+ # total_modules.append(f"model.layers.{i}.self_attn.o_proj")
+
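+        # Llama-3-8B has 32 decoder layers; track the three MLP projections in each.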
+ for i in range(32):
+ total_modules.append(f"model.layers.{i}.mlp.gate_proj")
+ total_modules.append(f"model.layers.{i}.mlp.up_proj")
+ total_modules.append(f"model.layers.{i}.mlp.down_proj")
+
+ return total_modules
+
+ def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
+ return batch["attention_mask"]
diff --git a/examples/requirements.txt b/examples/requirements.txt
index d7e148a..5fb9eed 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -5,3 +5,7 @@ tueplots
transformers
evaluate
datasets
+sentencepiece!=0.1.92
+nltk
+py7zr
+rouge-score
\ No newline at end of file
diff --git a/examples/swag/README.md b/examples/swag/README.md
index 3863330..4ff8146 100644
--- a/examples/swag/README.md
+++ b/examples/swag/README.md
@@ -1,7 +1,7 @@
# SWAG & RoBERTa Example
-This directory contains scripts for fine-tuning RoBERTa computing influence scores on the SWAG dataset. The pipeline is motivated from [this HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice).
-To get started, please install the necessary packages:
+This directory demonstrates fine-tuning RoBERTa on the SWAG dataset and computing influence scores. The implementation is inspired by [this HuggingFace example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice) and showcases how to define `post_process_per_sample_gradient`.
+Install the required packages:
```bash
pip install -r requirements.txt
@@ -9,7 +9,7 @@ pip install -r requirements.txt
## Training
-To fine-tune RoBERTa on SWAG, run the following command:
+To fine-tune RoBERTa on the SWAG dataset, run the following command:
```bash
python train.py --checkpoint_dir ./checkpoints \
@@ -21,61 +21,120 @@ python train.py --checkpoint_dir ./checkpoints \
--seed 1004
```
-This will fine-tune the model using the specified hyperparameters and save the final checkpoint in the `./checkpoints` directory.
+The final checkpoint will be saved in the `./checkpoints` directory.
## Computing Pairwise Influence Scores
-To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command:
+To calculate pairwise influence scores on 2000 query data points using `ekfac`, run:
```bash
-python analyze.py --query_batch_size 64 \
+python analyze.py --factor_batch_size 128 \
+ --query_batch_size 100 \
+ --train_batch_size 64 \
+ --checkpoint_dir ./checkpoints \
+ --factor_strategy ekfac
+```
+
+Alternative options for `factor_strategy` include `identity`, `diagonal`, and `kfac`.
+On an A100 (80GB), computing the pairwise scores (including EKFAC factors) takes approximately 10 hours:
+
+```
+----------------------------------------------------------------------------------------------------------------------------------
+| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Total | - | 11 | 3.5124e+04 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 2.9384e+04 | 1 | 2.9384e+04 | 83.656 |
+| Fit Lambda | 3578.7 | 1 | 3578.7 | 10.189 |
+| Fit Covariance | 2143.9 | 1 | 2143.9 | 6.1036 |
+| Perform Eigendecomposition | 10.213 | 1 | 10.213 | 0.029078 |
+| Save Eigendecomposition | 3.4398 | 1 | 3.4398 | 0.0097933 |
+| Save Covariance | 2.5179 | 1 | 2.5179 | 0.0071684 |
+| Save Pairwise Score | 1.2982 | 1 | 1.2982 | 0.0036959 |
+| Save Lambda | 0.68226 | 1 | 0.68226 | 0.0019424 |
+| Load All Factors | 0.013627 | 1 | 0.013627 | 3.8797e-05 |
+| Load Eigendecomposition | 0.0088496 | 1 | 0.0088496 | 2.5195e-05 |
+| Load Covariance | 0.008222 | 1 | 0.008222 | 2.3408e-05 |
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
+For more efficient computation, use half-precision:
+
+```bash
+python analyze.py --factor_batch_size 128 \
+ --query_batch_size 100 \
--train_batch_size 128 \
--use_half_precision \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
-You can also use `identity`, `diagonal`, and `kfac` for `factor_strategy`. On an A6000 (48GB), it takes roughly 95 minutes to compute the pairwise scores (including computing EKFAC factors):
+This reduces computation time to about 3 hours on an A100 (80GB) GPU.
```
-
+----------------------------------------------------------------------------------------------------------------------------------
+| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Total | - | 11 | 1.0935e+04 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 9576.4 | 1 | 9576.4 | 87.578 |
+| Fit Lambda | 932.07 | 1 | 932.07 | 8.524 |
+| Fit Covariance | 411.81 | 1 | 411.81 | 3.7661 |
+| Perform Eigendecomposition | 10.623 | 1 | 10.623 | 0.097145 |
+| Save Eigendecomposition | 1.4735 | 1 | 1.4735 | 0.013475 |
+| Save Covariance | 1.2953 | 1 | 1.2953 | 0.011846 |
+| Save Pairwise Score | 0.66271 | 1 | 0.66271 | 0.0060606 |
+| Save Lambda | 0.34022 | 1 | 0.34022 | 0.0031114 |
+| Load All Factors | 0.012041 | 1 | 0.012041 | 0.00011012 |
+| Load Covariance | 0.0079526 | 1 | 0.0079526 | 7.2728e-05 |
+| Load Eigendecomposition | 0.0076841 | 1 | 0.0076841 | 7.0273e-05 |
+----------------------------------------------------------------------------------------------------------------------------------
```
-For more efficient computation, use AMP half precision + query batching + DDP:
+Query batching (low-rank approximation to the query gradient; see **Section 3.2.2** from the paper) can be used to compute influence scores with a larger query batch size:
```bash
-torchrun --standalone --nnodes=1 --nproc-per-node=2 analyze.py --factor_batch_size 128 \
+python analyze.py --factor_batch_size 128 \
--query_batch_size 100 \
--train_batch_size 128 \
- --checkpoint_dir ./checkpoints \
- --factor_strategy ekfac \
--query_gradient_rank 32 \
--use_half_precision \
- --use_ddp
+ --checkpoint_dir ./checkpoints \
+ --factor_strategy ekfac
```
-This reduces computation time to about 20 minutes on an A100 (80GB) GPU:
+On an A100 (80GB) GPU, it takes roughly 1 hour to compute the pairwise scores with query batching:
```
-
+----------------------------------------------------------------------------------------------------------------------------------
+| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Total | - | 3 | 2007.9 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 2007.2 | 1 | 2007.2 | 99.966 |
+| Save Pairwise Score | 0.66464 | 1 | 0.66464 | 0.033102 |
+| Load All Factors | 0.012345 | 1 | 0.012345 | 0.00061484 |
+----------------------------------------------------------------------------------------------------------------------------------
```
## Evaluating Linear Datamodeling Score
The `evaluate_lds.py` script computes the [linear datamodeling score (LDS)](https://arxiv.org/abs/2303.14186). It measures the LDS obtained by
retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks).
-We obtain `xx` LDS (we get `xx` LDS with the AMP half precision + query batching + DDP).
-
-The script can also print top influential sequences for a given query.
+We obtain `0.33` LDS (and `0.30` LDS both with half precision and with half precision + query batching).
```
-Query Example:
- Sentence1: The west has preferred to focus on endangered animals, rather than endangered humans. African elephants are hunted down and stripped of tusks and hidden by poachers. Their numbers in Africa slumped from 1.2m to 600,000 in a decade until CITES - the Convention on International Trade in Endangered Species - banned the trade in ivory.
- Sentence2: African elephants are endangered by ivory poachers.
- Label: 0
-
+Query Data Example:
+ Option 0: He looks disgusted and spits it out onto the plate.He slides both hands around the crack.
+ Option 1: He looks disgusted and spits it out onto the plate.He passes someone to the bald guy.
+ Option 2: He looks disgusted and spits it out onto the plate.He picks up a piece of bread.