This corresponds to **Equation 20** in the paper. You can tune the following options (a configuration sketch follows the list):
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
- `lambda_data_partitions`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partitions`: Number of module partitions to use for computing Lambda matrices.
- `offload_activations_to_cpu`: Computing the per-sample-gradient requires saving the intermediate activations in memory. You can set `offload_activations_to_cpu=True` to cache these activations on the CPU. This is helpful for dealing with OOMs, but makes the overall computation slower.
- `use_iterative_lambda_aggregation`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications. This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `per_sample_gradient_dtype`: `dtype` for computing the per-sample-gradient. You can also use `torch.bfloat16` or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16` or `torch.float16`.
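As a reference point, here is a minimal sketch of how these options might be set. It assumes `FactorArguments` can be imported from `kronfluence.arguments` (mirroring the `ScoreArguments` import shown later), and the values are purely illustrative, not recommended defaults:

```python
import torch
from kronfluence.arguments import FactorArguments  # assumed import path, mirroring ScoreArguments

# Illustrative values; the argument names follow the list above.
factor_args = FactorArguments(
    lambda_max_examples=100_000,             # cap on data points used to fit Lambda matrices
    lambda_data_partitions=1,                # number of data partitions
    lambda_module_partitions=1,              # number of module partitions
    offload_activations_to_cpu=False,        # keep cached activations on GPU
    use_iterative_lambda_aggregation=False,  # use batched matrix multiplications
    per_sample_gradient_dtype=torch.float32,
    lambda_dtype=torch.float32,
)
```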
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors; a sketch combining them follows the list.
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
2. Try setting `use_iterative_lambda_aggregation=True` or `offload_activations_to_cpu=True`. (Try `use_iterative_lambda_aggregation=True` first.)
3. Try using lower precision for `per_sample_gradient_dtype` and `lambda_dtype`.
4. Try using `lambda_module_partitions > 1`.
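A sketch of steps 2-4 combined into one memory-leaner configuration. Step 1 (`per_device_batch_size`) is a batch-size setting used when fitting the factors, so it does not appear here; the import path is assumed as above and the values are illustrative:

```python
import torch
from kronfluence.arguments import FactorArguments  # assumed import path

# Memory-leaner settings for fitting Lambda matrices (illustrative).
factor_args = FactorArguments(
    use_iterative_lambda_aggregation=True,     # step 2: avoid batched copies of per-sample-gradients
    offload_activations_to_cpu=True,           # step 2: cache activations on the CPU (slower overall)
    per_sample_gradient_dtype=torch.bfloat16,  # step 3: lower precision
    lambda_dtype=torch.bfloat16,               # step 3: lower precision
    lambda_module_partitions=2,                # step 4: more than one module partition
)
```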
### FAQs
```python
import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(
    damping_factor=1e-08,
    amp_dtype=None,
    offload_activations_to_cpu=False,
    # More functionalities to compute influence scores (documented below).
    # ...
)
```
- `damping_factor`: A damping factor for the damped inverse Hessian-vector product (iHVP). If `None`, uses a heuristic based on the mean eigenvalues, `(0.1 x mean eigenvalues)`, as done in [this paper](https://arxiv.org/abs/2308.03296).
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `offload_activations_to_cpu`: Whether to offload cached activations to the CPU.
- `data_partitions`: Number of data partitions for computing influence scores.
- `module_partitions`: Number of module partitions for computing influence scores.
- `compute_per_module_scores`: Whether to return per-module influence scores. Instead of summing influences across all modules, this keeps track of intermediate module-wise scores.
- `compute_per_token_scores`: Whether to return per-token influence scores. Only applicable to transformer-based models.
- `aggregate_query_gradients`: Whether to use the summed query gradient instead of per-sample query gradients.
- `aggregate_train_gradients`: Whether to use the summed training gradient instead of per-sample training gradients.
- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
- `query_gradient_low_rank`: The rank for query batching (a low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) on the query batch. You can also use `torch.float64`.
- `query_gradient_accumulation_steps`: Number of query gradients to accumulate over. For example, with `query_gradient_accumulation_steps=2` and `query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing the per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`.
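For example, here is a sketch that enables low-rank query batching and per-module scores using the options above. The argument names come from the list; the specific values are illustrative, not recommendations:

```python
import torch
from kronfluence.arguments import ScoreArguments

# Illustrative: low-rank query batching plus module-wise score tracking.
score_args = ScoreArguments(
    damping_factor=1e-08,
    query_gradient_low_rank=64,           # low-rank approximation of the preconditioned query gradient
    query_gradient_accumulation_steps=2,  # accumulate query gradients before the training-gradient dot products
    compute_per_module_scores=True,       # keep module-wise scores instead of summing over modules
    score_dtype=torch.bfloat16,           # lower-precision influence scores
    precondition_dtype=torch.float32,     # keep preconditioning in full precision
)
```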
### Computing Influence Scores
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
## Acknowledgements
[Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.