diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index b844a53..f0bcb2e 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -28,8 +28,10 @@ jobs: pytest -vx tests/test_dataset_utils.py pytest -vx tests/test_testable_tasks.py pytest -vx tests/factors/test_covariances.py - pytest -vx tests/factors/test_eigens.py + pytest -vx tests/factors/test_eigendecompositions.py + pytest -vx tests/factors/test_lambdas.py pytest -vx tests/modules/test_modules.py pytest -vx tests/modules/test_per_sample_gradients.py + pytest -vx tests/modules/test_matmul.py pytest -vx tests/scores/test_pairwise_scores.py pytest -vx tests/scores/test_self_scores.py \ No newline at end of file diff --git a/.gitignore b/.gitignore index debad09..f4b18f5 100644 --- a/.gitignore +++ b/.gitignore @@ -164,8 +164,8 @@ cython_debug/ # Checkpoints and influence outputs checkpoints/ -analyses/ influence_results/ data/ +cache/ *.pth *.pt \ No newline at end of file diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index b0f1e5c..cdc5cf1 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -67,7 +67,7 @@ class YourTask(Task): ) -> torch.Tensor: # TODO: Complete this method. - def tracked_modules(self) -> Optional[List[str]]: + def get_influence_tracked_modules(self) -> Optional[List[str]]: # TODO: [Optional] Complete this method. return None # Compute influence scores on all available modules. @@ -89,7 +89,7 @@ model = prepare_model(model=model, task=task) ... ``` -If you have specified specific module names in `Task.tracked_modules`, `TrackedModule` will only be installed for these modules. +If you have specified specific module names in `Task.get_influence_tracked_modules`, `TrackedModule` will only be installed for these modules. **\[Optional\] Create a DDP and FSDP Module.** After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or @@ -140,7 +140,7 @@ Try rewriting the model so that it uses supported modules (as done for the `conv Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module. If there are specific modules you would like to see supported, please submit an issue. -**How should I write task.tracked_modules?** +**How should I write task.get_influence_tracked_modules?** We recommend using all supported modules for influence computations. However, if you would like to compute influence scores on subset of the modules (e.g., influence computations only on MLP layers for transformer or influence computation only on the last layer), inspect `model.named_modules()` to determine what modules to use. You can specify the list of module names you want to analyze. @@ -183,7 +183,7 @@ def forward(x: torch.Tensor) -> torch.Tensor: > [!WARNING] > The default arguments assume the module is used only once during the forward pass. > If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set -> `shared_parameters_exist=True` in `FactorArguments`. +> `has_shared_parameters=True` in `FactorArguments`. **Why are there so many arguments?** Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments` @@ -204,14 +204,13 @@ from kronfluence.arguments import FactorArguments factor_args = FactorArguments( strategy="ekfac", # Choose from "identity", "diagonal", "kfac", or "ekfac". 
use_empirical_fisher=False, - distributed_sync_steps=1000, amp_dtype=None, - shared_parameters_exist=False, + has_shared_parameters=False, # Settings for covariance matrix fitting. covariance_max_examples=100_000, - covariance_data_partition_size=1, - covariance_module_partition_size=1, + covariance_data_partitions=1, + covariance_module_partitions=1, activation_covariance_dtype=torch.float32, gradient_covariance_dtype=torch.float32, @@ -220,10 +219,10 @@ factor_args = FactorArguments( # Settings for Lambda matrix fitting. lambda_max_examples=100_000, - lambda_data_partition_size=1, - lambda_module_partition_size=1, - lambda_iterative_aggregate=False, - cached_activation_cpu_offload=False, + lambda_data_partitions=1, + lambda_module_partitions=1, + use_iterative_lambda_aggregation=False, + offload_activations_to_cpu=False, per_sample_gradient_dtype=torch.float32, lambda_dtype=torch.float32, ) @@ -237,7 +236,7 @@ You can change: - `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch) instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`. - `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`. -- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass. +- `has_shared_parameters`: Specifies whether the shared parameters exist in the forward pass. ### Fitting Covariance Matrices @@ -254,13 +253,13 @@ covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_fa This step corresponds to **Equation 16** in the paper. You can tune: - `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. Setting it to `None`, Kronfluence computes covariance matrices for all data points. -- `covariance_data_partition_size`: Number of data partitions to use for computing covariance matrices. -For example, when `covariance_data_partition_size = 2`, the dataset is split into 2 chunks and covariance matrices +- `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices. +For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter). -- `covariance_module_partition_size`: Number of module partitions to use for computing covariance matrices. -For example, when `covariance_module_partition_size = 2`, the module is split into 2 chunks and covariance matrices +- `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices. +For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total covariance matrices cannot fit into GPU memory). However, this will require multiple iterations over the dataset and can be slow. - `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. 
You can also use `torch.bfloat16` @@ -271,7 +270,7 @@ or `torch.float16`. **Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors. 1. Try reducing the `per_device_batch_size` when fitting covariance matrices. 2. Try using lower precision for `activation_covariance_dtype` and `gradient_covariance_dtype`. -3. Try setting `covariance_module_partition_size > 1`. +3. Try setting `covariance_module_partitions > 1`. ### Performing Eigendecomposition @@ -301,22 +300,22 @@ lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor") This corresponds to **Equation 20** in the paper. You can tune: - `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices. -- `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices. -- `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices. -- `cached_activation_cpu_offload`: Computing the per-sample-gradient requires saving the intermediate activation in memory. -You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower. -- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications. +- `lambda_data_partitions`: Number of data partitions to use for computing Lambda matrices. +- `lambda_module_partitions`: Number of module partitions to use for computing Lambda matrices. +- `offload_activations_to_cpu`: Computing the per-sample-gradient requires saving the intermediate activation in memory. +You can set `offload_activations_to_cpu=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower. +- `use_iterative_lambda_aggregation`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications. This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient. - `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16` or `torch.float16`. - `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16` -or `torch.float16`. Recommended to use `torch.float32`. +or `torch.float16`. **Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors. 1. Try reducing the `per_device_batch_size` when fitting Lambda matrices. -2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`. (Try out `lambda_iterative_aggregate=True` first.) +2. Try setting `use_iterative_lambda_aggregation=True` or `offload_activations_to_cpu=True`. (Try out `use_iterative_lambda_aggregation=True` first.) 3. Try using lower precision for `per_sample_gradient_dtype` and `lambda_dtype`. -4. Try using `lambda_module_partition_size > 1`. +4. Try using `lambda_module_partitions > 1`. ### FAQs @@ -339,21 +338,24 @@ import torch from kronfluence.arguments import ScoreArguments score_args = ScoreArguments( - damping=1e-08, - cached_activation_cpu_offload=False, - distributed_sync_steps=1000, + damping_factor=1e-08, amp_dtype=None, + offload_activations_to_cpu=False, # More functionalities to compute influence scores. 
-    data_partition_size=1,
-    module_partition_size=1,
-    per_module_score=False,
+    data_partitions=1,
+    module_partitions=1,
+    compute_per_module_scores=False,
+    compute_per_token_scores=False,
     use_measurement_for_self_influence=False,
+    aggregate_query_gradients=False,
+    aggregate_train_gradients=False,

     # Configuration for query batching.
-    query_gradient_rank=None,
+    query_gradient_low_rank=None,
+    use_full_svd=False,
     query_gradient_svd_dtype=torch.float32,
-    num_query_gradient_accumulations=1,
+    query_gradient_accumulation_steps=1,

     # Configuration for dtype.
     score_dtype=torch.float32,
@@ -362,23 +364,25 @@ score_args = ScoreArguments(
 )
 ```

-- `damping`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
+- `damping_factor`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
 `(0.1 x mean eigenvalues)` if `None`, as done in [this paper](https://arxiv.org/abs/2308.03296).
-- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
 - `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
-- `data_partition_size`: Number of data partitions for computing influence scores.
-- `module_partition_size`: Number of module partitions for computing influence scores.
-- `per_module_score`: Whether to return a per-module influence scores. Instead of summing over influences across
+- `offload_activations_to_cpu`: Whether to offload cached activations to CPU.
+- `data_partitions`: Number of data partitions for computing influence scores.
+- `module_partitions`: Number of module partitions for computing influence scores.
+- `compute_per_module_scores`: Whether to return per-module influence scores. Instead of summing over influences across
 all modules, this will keep track of intermediate module-wise scores.
-- - `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
-- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
+- `compute_per_token_scores`: Whether to return per-token influence scores. Only applicable to transformer-based models.
+- `aggregate_query_gradients`: Whether to use the summed query gradient instead of per-sample query gradients.
+- `aggregate_train_gradients`: Whether to use the summed training gradient instead of per-sample training gradients.
+- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
+- `query_gradient_low_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
 - `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use
 `torch.float64`.
-- `num_query_gradient_accumulations`: Number of query gradients to accumulate over. For example, when `num_query_gradient_accumulations=2` with
+- `query_gradient_accumulation_steps`: Number of query gradients to accumulate over. For example, when `query_gradient_accumulation_steps=2` with
 `query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`. - `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`. -- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`, -but `torch.float32` is recommended. +- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`. ### Computing Influence Scores @@ -409,12 +413,12 @@ vector will correspond to `g_m^T ⋅ H^{-1} ⋅ g_l`, where `g_m` is the gradien **Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors. 1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`. -2. Try setting `cached_activation_cpu_offload=True`. +2. Try setting `offload_activations_to_cpu=True`. 3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`. 4. Try using lower precision for `precondition_dtype`. -5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query +5. Try setting `query_gradient_low_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query batching is only supported for computing pairwise influence scores, not self-influence scores. -6. Try setting `module_partition_size > 1`. +6. Try setting `module_partitions > 1`. ### FAQs diff --git a/README.md b/README.md index 9711087..ec6b611 100644 --- a/README.md +++ b/README.md @@ -182,7 +182,7 @@ Please address any reported issues before submitting your PR. ## Acknowledgements [Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations. -I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback. +I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback. ## License diff --git a/dev_requirements.txt b/dev_requirements.txt index 7c16350..5d23191 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,10 +4,11 @@ accelerate>=0.31.0 einops>=0.8.0 einconv>=0.1.0 opt_einsum>=3.3.0 +scikit-learn>=1.4.0 safetensors>=0.4.2 tqdm>=4.66.4 datasets>=2.20.0 -transformers>=4.41.2 +transformers>=4.42.0 isort==5.13.2 pylint==3.2.3 pytest==8.2.2 diff --git a/examples/README.md b/examples/README.md index 4d58a09..328a50f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -12,22 +12,22 @@ pip install -r requirements.txt Alternatively, navigate to each example folder and run `pip install -r requirements.txt`. - ## List of Tasks Our examples cover the following tasks:
-| Task                 |     Example datasets     |
+| Task                 |     Example Datasets     |
 |----------------------|:------------------------:|
 | Regression           |           UCI            |
-| Image Classification |   CIFAR-10 / ImageNet    |
+| Image Classification |   CIFAR-10 & ImageNet    |
 | Text Classification  |           GLUE           |
 | Multiple-Choice      |           SWAG           |
-| Language Modeling    | WikiText-2 / OpenWebText |
+| Summarization        |      CNN/DailyMail       |
+| Language Modeling    | WikiText-2 & OpenWebText |
These examples demonstrate various use cases of Kronfluence, including the usage of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel). -Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue. +Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue. \ No newline at end of file diff --git a/examples/cifar/README.md b/examples/cifar/README.md index c3cda87..5474295 100644 --- a/examples/cifar/README.md +++ b/examples/cifar/README.md @@ -26,7 +26,7 @@ This will train the model using the specified hyperparameters and save the train ## Computing Pairwise Influence Scores -To compute pairwise influence scores on 2000 query data points using the `ekfac` factorization strategy, run the following command: +To compute pairwise influence scores on 2000 query data points using the `ekfac` strategy, run the following command: ```bash python analyze.py --query_batch_size 1000 \ @@ -41,23 +41,23 @@ In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as t ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 112.83 | 100 % | +| Total | - | 11 | 106.38 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 47.989 | 1 | 47.989 | 42.532 | -| Fit Lambda | 34.639 | 1 | 34.639 | 30.7 | -| Fit Covariance | 21.841 | 1 | 21.841 | 19.357 | -| Save Pairwise Score | 3.5998 | 1 | 3.5998 | 3.1905 | -| Perform Eigendecomposition | 2.7724 | 1 | 2.7724 | 2.4572 | -| Save Covariance | 0.85695 | 1 | 0.85695 | 0.75951 | -| Save Eigendecomposition | 0.85628 | 1 | 0.85628 | 0.75892 | -| Save Lambda | 0.12327 | 1 | 0.12327 | 0.10925 | -| Load Eigendecomposition | 0.056494 | 1 | 0.056494 | 0.05007 | -| Load All Factors | 0.048981 | 1 | 0.048981 | 0.043412 | -| Load Covariance | 0.046798 | 1 | 0.046798 | 0.041476 | +| Compute Pairwise Score | 46.745 | 1 | 46.745 | 43.941 | +| Fit Lambda | 34.885 | 1 | 34.885 | 32.793 | +| Fit Covariance | 22.538 | 1 | 22.538 | 21.187 | +| Perform Eigendecomposition | 0.91424 | 1 | 0.91424 | 0.85941 | +| Save Pairwise Score | 0.81219 | 1 | 0.81219 | 0.76348 | +| Save Covariance | 0.22351 | 1 | 0.22351 | 0.21011 | +| Save Eigendecomposition | 0.21617 | 1 | 0.21617 | 0.20321 | +| Save Lambda | 0.031038 | 1 | 0.031038 | 0.029177 | +| Load Eigendecomposition | 0.010442 | 1 | 0.010442 | 0.0098156 | +| Load All Factors | 0.0026517 | 1 | 0.0026517 | 0.0024927 | +| Load Covariance | 0.0016419 | 1 | 0.0016419 | 0.0015435 | ---------------------------------------------------------------------------------------------------------------------------------- ``` -To use AMP when computing influence scores (in addition to half precision when computing influence factors and scores), run: +To use AMP when computing influence scores, run: ```bash python analyze.py --query_batch_size 1000 \ @@ -73,19 +73,19 @@ This reduces computation time to about 40 seconds on an A100 (80GB) GPU: 
---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 42.316 | 100 % | +| Total | - | 11 | 35.965 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 19.565 | 1 | 19.565 | 46.235 | -| Fit Lambda | 9.173 | 1 | 9.173 | 21.677 | -| Fit Covariance | 7.3723 | 1 | 7.3723 | 17.422 | -| Perform Eigendecomposition | 2.6613 | 1 | 2.6613 | 6.2891 | -| Save Pairwise Score | 2.0156 | 1 | 2.0156 | 4.7633 | -| Save Covariance | 0.71699 | 1 | 0.71699 | 1.6944 | -| Save Eigendecomposition | 0.52561 | 1 | 0.52561 | 1.2421 | -| Load Covariance | 0.15732 | 1 | 0.15732 | 0.37177 | -| Save Lambda | 0.063394 | 1 | 0.063394 | 0.14981 | -| Load Eigendecomposition | 0.051395 | 1 | 0.051395 | 0.12146 | -| Load All Factors | 0.014144 | 1 | 0.014144 | 0.033425 | +| Compute Pairwise Score | 18.012 | 1 | 18.012 | 50.082 | +| Fit Lambda | 9.2271 | 1 | 9.2271 | 25.656 | +| Fit Covariance | 7.134 | 1 | 7.134 | 19.836 | +| Perform Eigendecomposition | 0.87962 | 1 | 0.87962 | 2.4457 | +| Save Pairwise Score | 0.45432 | 1 | 0.45432 | 1.2632 | +| Save Covariance | 0.12861 | 1 | 0.12861 | 0.35759 | +| Save Eigendecomposition | 0.11296 | 1 | 0.11296 | 0.31407 | +| Save Lambda | 0.010712 | 1 | 0.010712 | 0.029784 | +| Load All Factors | 0.002736 | 1 | 0.002736 | 0.0076074 | +| Load Covariance | 0.0016696 | 1 | 0.0016696 | 0.0046421 | +| Load Eigendecomposition | 0.0014892 | 1 | 0.0014892 | 0.0041406 | ---------------------------------------------------------------------------------------------------------------------------------- ``` @@ -131,19 +131,19 @@ On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 122.28 | 100 % | +| Total | - | 11 | 121.85 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Self-Influence Score | 61.999 | 1 | 61.999 | 50.701 | -| Fit Lambda | 34.629 | 1 | 34.629 | 28.319 | -| Fit Covariance | 21.807 | 1 | 21.807 | 17.833 | -| Perform Eigendecomposition | 1.8041 | 1 | 1.8041 | 1.4754 | -| Save Covariance | 0.86378 | 1 | 0.86378 | 0.70638 | -| Save Eigendecomposition | 0.84935 | 1 | 0.84935 | 0.69458 | -| Save Lambda | 0.18367 | 1 | 0.18367 | 0.1502 | -| Load Eigendecomposition | 0.052867 | 1 | 0.052867 | 0.043233 | -| Load Covariance | 0.051723 | 1 | 0.051723 | 0.042298 | -| Load All Factors | 0.031986 | 1 | 0.031986 | 0.026158 | -| Save Self-Influence Score | 0.010352 | 1 | 0.010352 | 0.0084653 | +| Compute Self-Influence Score | 62.778 | 1 | 62.778 | 51.519 | +| Fit Lambda | 35.174 | 1 | 35.174 | 28.866 | +| Fit Covariance | 22.582 | 1 | 22.582 | 18.532 | +| Perform Eigendecomposition | 0.82656 | 1 | 0.82656 | 0.67832 | +| Save Covariance | 0.2478 | 1 | 0.2478 | 0.20336 | +| Save Eigendecomposition | 0.22042 | 1 | 
0.22042 | 0.18088 | +| Save Lambda | 0.018463 | 1 | 0.018463 | 0.015152 | +| Load All Factors | 0.0027554 | 1 | 0.0027554 | 0.0022612 | +| Load Covariance | 0.0016607 | 1 | 0.0016607 | 0.0013628 | +| Load Eigendecomposition | 0.0015408 | 1 | 0.0015408 | 0.0012645 | +| Save Self-Influence Score | 0.0010841 | 1 | 0.0010841 | 0.00088966 | ---------------------------------------------------------------------------------------------------------------------------------- ``` diff --git a/examples/cifar/half_precision_analysis.py b/examples/cifar/half_precision_analysis.py index c62c14b..3f9a41d 100644 --- a/examples/cifar/half_precision_analysis.py +++ b/examples/cifar/half_precision_analysis.py @@ -1,6 +1,8 @@ import logging import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import spearmanr from tueplots import markers from kronfluence.analyzer import Analyzer @@ -25,13 +27,19 @@ def main(): plt.rcParams["axes.axisbelow"] = True # Only plot first 3000 points to avoid clutter. - idx = 0 + idx = 79 plt.scatter(half_scores[idx][:3000], scores[idx][:3000], edgecolor="k") plt.grid() plt.xlabel("bfloat16") plt.ylabel("float32") plt.show() + # Compute the averaged spearman correlation. + all_corr = [] + for i in range(100): + all_corr.append(spearmanr(scores[i], half_scores[i])[0]) + logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}") + if __name__ == "__main__": main() diff --git a/examples/cifar/requirements.txt b/examples/cifar/requirements.txt index 5a65422..a667c1f 100644 --- a/examples/cifar/requirements.txt +++ b/examples/cifar/requirements.txt @@ -1,4 +1,3 @@ scikit-learn -jupyter matplotlib tueplots \ No newline at end of file diff --git a/examples/dailymail/README.md b/examples/dailymail/README.md new file mode 100644 index 0000000..a7e5ab6 --- /dev/null +++ b/examples/dailymail/README.md @@ -0,0 +1,73 @@ +# CNN/DailyMail & T5 Example + +This directory contains scripts for fine-tuning T5 and computing influence scores on the CNN/DailyMail dataset. The pipeline is motivated from [this HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization). +To begin, install the necessary packages: + +```bash +pip install -r requirements.txt +``` + +## Training + +To fine-tune T5 on CNN/DailyMail, run the following command: + +```bash +python train.py --checkpoint_dir ./checkpoints \ + --train_batch_size 16 \ + --eval_batch_size 32 \ + --learning_rate 5e-05 \ + --weight_decay 0.01 \ + --num_train_epochs 3 \ + --seed 1004 +``` + +This will fine-tune the model using the specified hyperparameters and save the final checkpoint in the `./checkpoints` directory. + +## Computing Pairwise Influence Scores + +To calculate pairwise influence scores on 10 query data points using `ekfac`, run: + +```bash +python analyze.py --factor_batch_size 64 \ + --query_batch_size 10 \ + --train_batch_size 128 \ + --use_half_precision \ + --checkpoint_dir ./checkpoints \ + --factor_strategy ekfac +``` + +Alternative options for `factor_strategy` include `identity`, `diagonal`, and `kfac`. 
On an A100 (80GB), computing the pairwise scores (including EKFAC factors) takes approximately 1 hour: + +``` +---------------------------------------------------------------------------------------------------------------------------------- +| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | +---------------------------------------------------------------------------------------------------------------------------------- +| Total | - | 11 | 3397.1 | 100 % | +---------------------------------------------------------------------------------------------------------------------------------- +| Compute Pairwise Score | 1905.6 | 1 | 1905.6 | 56.093 | +| Fit Lambda | 747.5 | 1 | 747.5 | 22.004 | +| Fit Covariance | 734.03 | 1 | 734.03 | 21.607 | +| Perform Eigendecomposition | 8.4236 | 1 | 8.4236 | 0.24796 | +| Save Eigendecomposition | 0.79164 | 1 | 0.79164 | 0.023303 | +| Save Covariance | 0.60366 | 1 | 0.60366 | 0.01777 | +| Save Lambda | 0.1514 | 1 | 0.1514 | 0.0044566 | +| Load All Factors | 0.027977 | 1 | 0.027977 | 0.00082354 | +| Save Pairwise Score | 0.01082 | 1 | 0.01082 | 0.00031851 | +| Load Covariance | 0.010015 | 1 | 0.010015 | 0.0002948 | +| Load Eigendecomposition | 0.0096806 | 1 | 0.0096806 | 0.00028497 | +---------------------------------------------------------------------------------------------------------------------------------- +``` + +## Inspecting Top Influential Sequences + +The `inspect_examples.py` script prints top influential sequences for a given query. + +``` +Query Data Example: + Input: summarize: (CNN)My vote for Father of the Year goes to Curt Schilling. The former Major League Baseball pitcher recently fired off a series of fastballs and mowed down a group of Twitter trolls who made the mistake of tweeting vulgar and sexually-explicit comments about Schilling's teenage daughter. The drama started, innocently enough, on February 25, when Schilling played the role of a proud father. He sent a tweet congratulating his daughter, Gabby, on being accepted to Salve Regina University, where she'll play softball. It read: "Congrats to Gabby Schilling who will pitch for the Salve Regina Seahawks next year!! — Curt Schilling (@gehrig38)" Almost immediately, responses came in from young men, complete strangers who apparently followed Schilling on Twitter. The tweets quickly went from immature, to creepy, to repugnant. Threats of rape were common. The tweets were deleted, and the accounts were closed after this story went viral. But not before Schilling captured some of the images and posted them on his blog. What was said about 17-year-old Gabby Schilling wasn't just obnoxious. It was vile and obscene. What was said wasn't just mean and ugly. It was threatening and scary. As a parent, it's the kind of thing that makes you rethink your opposition to public caning as a logical punishment for such transgressions. These misogynistic cowards may have thought they could hide in the darkness of anonymity, the sort that many have come to expect from social media sites, where you feel free to be a despicable human being because, you think, no one will ever find out who you really are and hold you accountable for your words. If so, they thought wrong. They couldn't hide. They were found out, and they got the throttling they so richly deserved. Thanks to dad. According to Schilling, who made it his mission to track down these cretins and make sure those they associate with know who they really are, two people have already paid a price due to their tweets. 
One was a student disc jockey at a community college in New Jersey, who was suspended, and the other was a part-time ticket seller for the New York Yankees, who was fired. Concerned that this is an example of exactly the kind of cyberbullying that leads some teenagers to commit suicide, Schilling is also thinking about taking legal action against some of the other people involved. Bravo for him. I'm sure that, all across America, dads with daughters -- after reading some of the horrible things that were said about this young girl -- are marveling at Schilling's self-control. I have two daughters of my own, and he's a better man than me. If ever there was a case where profanity-spewing malcontents deserved to have their mouths washed out with soap, this is it. So what additional insights can we draw, and what larger lessons can we learn, from this unexpected but predictable collision of old-fashioned parenthood and newfangled media? There are a few. The first is about accountability, the very thing that the young men who posted these hurtful messages were trying to avoid. But Schilling wouldn't let them. At their best, social media sites like Twitter, Facebook, Instagram and others allow the sharing the information and the building of a sense of community. At their worst, they become digital sandboxes and locker rooms where people think have a license to misbehave without having to worry about consequences. We need to applaud efforts like this that promote greater online accountability. There's also something to be said about protective parents, and how essential they are to a working society. We should still be concerned about those overprotective parents who hover like helicopters from little league to job interviews. We shouldn't bubblewrap our kids, and keep them from playing outdoors, and then sit around wondering why they're soft, timid, and risk-averse. But protective parents -- the kind who shield their kids from real danger -- never go out of style. A parent's top job is to protect his children. Schilling did his job. Finally, it's worth reminding everyone that freedom of expression does not mean freedom from rules, standards, and expectations that should guide your behavior. There are things you don't say. There are boundaries, ways that we expect you to behave so you don't terrorize other people or bring shame upon yourself, your friends, and + Label: Ruben Navarrette: Schilling deserves praise for taking on online haters for offensive comments about his daughter. Navarrette: In protecting his child, Schilling set a model for parenting and taught us a lesson about social media. + +Top Influential Example: + Input: summarize: (CNN) -- What is it with juries in high-profile cases in Southern California? Over the years, they've become a national joke. But no one is laughing. Instead, with each travesty of justice and every acquittal that should have been a conviction, you're left wondering just what trial these 12 folks were watching and questioning whether we should have higher standards for who sits on a jury. Sometimes, the juries in local and state courts get it wrong, and the Justice Department must step in and make it right. Think back to the acquittal in April 1992 of four Los Angeles police officers who, one year earlier, savagely beat motorist Rodney King. They walked out of a courtroom in Simi Valley, California, as free men -- sparking days of rioting, looting and violence. 
At the time, the conventional thinking on newspaper editorial pages and on talk radio was the jurors in that largely white suburb of Los Angeles, which was itself home to many active-duty and retired police officers, saw the police force as their line of defense against undesirables like King. So naturally, the argument went, they would cut them some slack. The officers were tried again, and convicted in federal court of violating King's civil rights. Justice was finally served. Here we go again. There hasn't been much civil unrest over what happened to Kelly Thomas, the homeless and mentally ill man who -- on July 5, 2011 -- was beaten to death by a swarm of police officers in Fullerton, California. But now that the verdict is in, literally, on the two former officers who were charged in his death, there is plenty of outrage on talk radio, online and in other public forums. Another 12 people who swore an oath to consider the evidence and the law and make sure that justice is served appear to have gotten it terribly wrong. This week, that jury in Santa Ana, California -- a city about 30 miles southeast of Los Angeles -- produced a wave of gasps in the courtroom when it announced that it had found Manuel Ramos, who had been charged with second-degree murder and involuntary manslaughter, and Jay Cicinelli, who was charged with involuntary manslaughter and excessive use of force, not guilty on all counts. What? The beating was caught on a surveillance tape. When you watch those 33 minutes of footage, assuming you can stomach the experience, it's hard to believe that anyone could declare the perpetrators "not guilty." The surveillance camera footage shows Thomas being beaten and stunned with a Taser by police until he was unrecognizable and unconscious. You see a defenseless and compliant young man screaming in pain, saying he's sorry and pleading for help from his father. His words will haunt you, "Daddy, help! They're killing me!" According to prosecutors, the young man suffered brain injuries, facial fractures, broken ribs and extensive bruises and abrasions. He wound up lying in a pool of blood. He died five days later. This was not a by-the-book case of police officers using all necessary force to subdue a suspect who was resisting arrest -- a suspect, by the way, who had committed no crime. This was not, as Ramos' attorney claimed, a case of police offices simply "doing their job" with "no malice in their heart." Check the video. Early on in the confrontation, Ramos appears to tell the young man who is sitting on the ground: "You see my fists? They're getting ready to f--- you up!" Another officer is heard telling a comrade: "We ran out of options so I got to the end of my Taser and I... smashed his face to hell." There is the malice. This was abuse of power and an instance of bullying behind a badge. It happens more than we'd like to think in America. But this time, it went too far. And a man died, and a family was shattered. Yet, the jury somehow missed all this? How does this happen? In Los Angeles, people are saying that the mentally ill are the new Rodney King. In the same way that the jury in Simi Valley was inclined to back the officers who it saw as protecting them from people like King, now the jury in Santa Ana is backing the officers who it counts on to prod people like Thomas to move along, leave the streets, and get out of sight. It's a plausible explanation + Label: Ruben Navarrette: Too many high-profile cases in California produce travesties of justice. 
He says jury acquitted two ex-cops in malicious beating death captured on video. He says case showed abuse of power, bullying behind badge; happens too often in U.S. Navarrette: Only one place left that can right this wrong: The Justice Department. +``` \ No newline at end of file diff --git a/examples/dailymail/__init__.py b/examples/dailymail/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/dailymail/analyze.py b/examples/dailymail/analyze.py new file mode 100644 index 0000000..032cc08 --- /dev/null +++ b/examples/dailymail/analyze.py @@ -0,0 +1,275 @@ +import argparse +import logging +import os +from typing import Dict, List + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import DataCollatorForSeq2Seq + +from examples.dailymail.pipeline import ( + construct_t5, + get_dailymail_dataset, + get_tokenizer, +) +from kronfluence.analyzer import Analyzer, prepare_model +from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.task import Task +from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments +from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments +from kronfluence.utils.dataset import DataLoaderKwargs +from kronfluence.utils.model import apply_ddp + +BATCH_TYPE = Dict[str, torch.Tensor] +try: + LOCAL_RANK = int(os.environ["LOCAL_RANK"]) + WORLD_RANK = int(os.environ["RANK"]) + WORLD_SIZE = int(os.environ["WORLD_SIZE"]) +except KeyError: + LOCAL_RANK = WORLD_RANK = WORLD_SIZE = 0 + + +def parse_args(): + parser = argparse.ArgumentParser(description="Influence analysis on CNN/DailyMail dataset.") + + parser.add_argument( + "--checkpoint_dir", + type=str, + default="./checkpoints", + help="A path that is storing the final checkpoint of the model.", + ) + + parser.add_argument( + "--factor_strategy", + type=str, + default="ekfac", + help="Strategy to compute influence factors.", + ) + parser.add_argument( + "--query_gradient_rank", + type=int, + default=-1, + help="Rank for the low-rank query gradient approximation.", + ) + parser.add_argument( + "--use_half_precision", + action="store_true", + default=False, + help="Whether to use half precision for computing factors and scores.", + ) + parser.add_argument( + "--use_ddp", + action="store_true", + default=False, + help="Whether to use DDP for computing factors and scores.", + ) + parser.add_argument( + "--factor_batch_size", + type=int, + default=64, + help="Batch size for computing influence factors.", + ) + parser.add_argument( + "--query_batch_size", + type=int, + default=10, + help="Batch size for computing query gradients.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=128, + help="Batch size for computing training gradients.", + ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Boolean flag to profile computations.", + ) + args = parser.parse_args() + + if args.checkpoint_dir is not None: + os.makedirs(args.checkpoint_dir, exist_ok=True) + + return args + + +class SummarizationTask(Task): + def compute_train_loss( + self, + batch: BATCH_TYPE, + model: nn.Module, + sample: bool = False, + ) -> torch.Tensor: + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + decoder_input_ids=batch["decoder_input_ids"], + ).logits + + if not sample: + return F.cross_entropy( + logits.view(-1, logits.size(-1)), batch["labels"].view(-1), ignore_index=-100, reduction="sum" + ) + with 
torch.no_grad(): + probs = torch.nn.functional.softmax(logits.view(-1, logits.size(-1)).detach(), dim=-1) + sampled_labels = torch.multinomial( + probs, + num_samples=1, + ).flatten() + masks = batch["labels"].view(-1) == -100 + sampled_labels[masks] = -100 + return F.cross_entropy(logits.view(-1, logits.size(-1)), sampled_labels, reduction="sum") + + def compute_measurement( + self, + batch: BATCH_TYPE, + model: nn.Module, + ) -> torch.Tensor: + # Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py. + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + decoder_input_ids=batch["decoder_input_ids"], + ).logits + logits = logits.view(-1, logits.size(-1)) + + labels = batch["labels"].view(-1) + bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) + logits_correct = logits[bindex, labels] + + cloned_logits = logits.clone() + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) + + margins = logits_correct - cloned_logits.logsumexp(dim=-1) + masks = batch["labels"].view(-1) != -100 + return -margins[masks].sum() + + def get_influence_tracked_modules(self) -> List[str]: + total_modules = [] + + # Add attention layers: + for i in range(6): + total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.q") + total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.k") + total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.v") + total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.o") + + total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.q") + total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.k") + total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.v") + total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.o") + + total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.q") + total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.k") + total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.v") + total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.o") + + # Add MLP layers: + for i in range(6): + total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wi") + total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wo") + + total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wi") + total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wo") + + return total_modules + + def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: + return batch["attention_mask"] + + +def main(): + args = parse_args() + logging.basicConfig(level=logging.INFO) + + # Prepare the dataset. + train_dataset = get_dailymail_dataset( + split="eval_train", + ) + eval_dataset = get_dailymail_dataset( + split="valid", + ) + tokenizer = get_tokenizer() + + # Prepare the trained model. + model = construct_t5() + + # Define task and prepare model. + task = SummarizationTask() + model = prepare_model(model, task) + + if args.use_ddp: + model = apply_ddp( + model=model, + local_rank=LOCAL_RANK, + rank=WORLD_RANK, + world_size=WORLD_SIZE, + ) + + analyzer = Analyzer( + analysis_name="dailymail", + model=model, + task=task, + profile=args.profile, + ) + # Configure parameters for DataLoader. 
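+    # DataCollatorForSeq2Seq pads input features dynamically and pads labels with
+    # label_pad_token_id (-100), which both compute_train_loss (ignore_index=-100)
+    # and compute_measurement (the -100 mask) skip when computing the loss.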
+ label_pad_token_id = -100 + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=None, + ) + + dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=data_collator) + analyzer.set_dataloader_kwargs(dataloader_kwargs) + + # Compute influence factors. + factors_name = args.factor_strategy + factor_args = FactorArguments(strategy=args.factor_strategy) + if args.use_half_precision: + factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16) + factors_name += "_half" + if args.use_ddp: + factors_name += "_ddp" + analyzer.fit_all_factors( + factors_name=factors_name, + dataset=train_dataset, + per_device_batch_size=args.factor_batch_size, + factor_args=factor_args, + overwrite_output_dir=False, + ) + + # Compute pairwise scores. + score_args = ScoreArguments() + scores_name = factor_args.strategy + if args.use_half_precision: + score_args = all_low_precision_score_arguments(dtype=torch.bfloat16) + scores_name += "_half" + rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None + if rank is not None: + score_args.query_gradient_low_rank = rank + score_args.query_gradient_accumulation_steps = 10 + scores_name += f"_qlr{rank}" + if args.use_ddp: + scores_name += "_ddp" + analyzer.compute_pairwise_scores( + score_args=score_args, + scores_name=scores_name, + factors_name=factors_name, + query_dataset=eval_dataset, + query_indices=list(range(10)), + train_dataset=train_dataset, + per_device_query_batch_size=args.query_batch_size, + per_device_train_batch_size=args.train_batch_size, + overwrite_output_dir=False, + ) + scores = analyzer.load_pairwise_scores(scores_name)["all_modules"] + logging.info(f"Scores shape: {scores.shape}") + + +if __name__ == "__main__": + main() diff --git a/examples/dailymail/inspect_examples.py b/examples/dailymail/inspect_examples.py new file mode 100644 index 0000000..2059b5c --- /dev/null +++ b/examples/dailymail/inspect_examples.py @@ -0,0 +1,37 @@ +import logging + +import torch + +from examples.dailymail.pipeline import get_dailymail_dataset, get_tokenizer +from kronfluence.analyzer import Analyzer + + +def main(): + logging.basicConfig(level=logging.INFO) + + # You might need to change the path. 
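+    # Scores are stored under influence_results/<analysis_name>/scores_<scores_name>/;
+    # analyze.py used analysis_name="dailymail", and appends "_half" to the scores
+    # name when --use_half_precision is set.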
+ strategy = "ekfac" + scores = Analyzer.load_file(f"influence_results/dailymail/scores_{strategy}_half/pairwise_scores.safetensors")[ + "all_modules" + ].to(dtype=torch.float32) + + eval_idx = 1 + train_dataset = get_dailymail_dataset( + split="eval_train", + ) + eval_dataset = get_dailymail_dataset( + split="valid", + ) + tokenizer = get_tokenizer() + print("Query Data Example:") + print(f"Input: {tokenizer.decode(eval_dataset[eval_idx]['input_ids'])}") + print(f"Label: {tokenizer.decode(eval_dataset[eval_idx]['labels'])}") + + top_idx = int(torch.argsort(scores[eval_idx], descending=True)[0]) + print("Top Influential Example:") + print(f"Input: {tokenizer.decode(train_dataset[top_idx]['input_ids'])}") + print(f"Label: {tokenizer.decode(train_dataset[top_idx]['labels'])}") + + +if __name__ == "__main__": + main() diff --git a/examples/dailymail/pipeline.py b/examples/dailymail/pipeline.py new file mode 100644 index 0000000..f221cdb --- /dev/null +++ b/examples/dailymail/pipeline.py @@ -0,0 +1,110 @@ +from typing import Any, List + +import torch.nn as nn +from datasets import Dataset, load_dataset +from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer + +summarization_name_mapping = { + "amazon_reviews_multi": ("review_body", "review_title"), + "big_patent": ("description", "abstract"), + "cnn_dailymail": ("article", "highlights"), + "orange_sum": ("text", "summary"), + "pn_summary": ("article", "summary"), + "psc": ("extract_text", "summary_text"), + "samsum": ("dialogue", "summary"), + "thaisum": ("body", "summary"), + "xglue": ("news_body", "news_title"), + "xsum": ("document", "summary"), + "wiki_summary": ("article", "highlights"), + "multi_news": ("document", "summary"), +} + + +MODEL_NAME = "google-t5/t5-small" + + +def construct_t5() -> nn.Module: + config = AutoConfig.from_pretrained( + MODEL_NAME, + trust_remote_code=True, + ) + return AutoModelForSeq2SeqLM.from_pretrained( + MODEL_NAME, + from_tf=False, + config=config, + ignore_mismatched_sizes=False, + trust_remote_code=True, + ) + + +def get_tokenizer() -> Any: + return AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True) + + +def get_dailymail_dataset( + split: str, + indices: List[int] = None, +) -> Dataset: + raw_datasets = load_dataset("cnn_dailymail", "3.0.0") + + tokenizer = get_tokenizer() + column_names = raw_datasets["train"].column_names + dataset_columns = summarization_name_mapping.get("cnn_dailymail", None) + text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + + max_source_length = 1024 + max_target_length = 128 + padding = False + prefix = "summarize: " + + def preprocess_function(examples): + inputs = examples[text_column] + targets = examples[summary_column] + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True) + + labels = tokenizer( + text_target=targets, + max_length=max_target_length, + padding=padding, + truncation=True, + ) + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + if split == "train" or split == "eval_train": + train_dataset = raw_datasets["train"] + train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=None, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on dataset.", + ) + ds = train_dataset + else: + valid_dataset = 
raw_datasets["validation"] + eval_dataset = valid_dataset.map( + preprocess_function, + batched=True, + num_proc=None, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on dataset.", + ) + ds = eval_dataset + + if indices is not None: + ds = ds.select(indices) + + return ds + + +if __name__ == "__main__": + from kronfluence import Analyzer + + model = construct_t5() + print(Analyzer.get_module_summary(model)) diff --git a/examples/dailymail/requirements.txt b/examples/dailymail/requirements.txt new file mode 100644 index 0000000..7ae6637 --- /dev/null +++ b/examples/dailymail/requirements.txt @@ -0,0 +1,7 @@ +sentencepiece!=0.1.92 +nltk +py7zr +rouge-score +transformers +evaluate +datasets diff --git a/examples/dailymail/train.py b/examples/dailymail/train.py new file mode 100644 index 0000000..686b3d4 --- /dev/null +++ b/examples/dailymail/train.py @@ -0,0 +1,228 @@ +import argparse +import logging +import os +import time +from typing import Any, Dict + +import evaluate +import nltk +import numpy as np +import torch +import torch.nn.functional as F +from accelerate.utils import send_to_device, set_seed +from filelock import FileLock +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils import data +from transformers import DataCollatorForSeq2Seq + +from examples.dailymail.pipeline import ( + construct_t5, + get_dailymail_dataset, + get_tokenizer, +) + +try: + nltk.data.find("tokenizers/punkt") +except (LookupError, OSError): + with FileLock(".lock") as lock: + nltk.download("punkt", quiet=True) + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train seq2seq models on CNN/DailyMail dataset.") + + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size for the training dataloader.", + ) + parser.add_argument( + "--eval_batch_size", + type=int, + default=32, + help="Batch size for the evaluation dataloader.", + ) + + parser.add_argument( + "--learning_rate", + type=float, + default=5e-05, + help="Fixed learning rate to train the model.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.01, + help="Weight decay to train the model.", + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=3, + help="Total number of epochs to train the model.", + ) + + parser.add_argument( + "--seed", + type=int, + default=1004, + help="A seed for reproducible training pipeline.", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default="./checkpoints", + help="A path to store the final checkpoint.", + ) + args = parser.parse_args() + + if args.checkpoint_dir is not None: + os.makedirs(args.checkpoint_dir, exist_ok=True) + + return args + + +def train( + dataset: data.Dataset, + tokenizer: Any, + batch_size: int, + num_train_epochs: int, + learning_rate: float, + weight_decay: float, +) -> nn.Module: + model = construct_t5().to(DEVICE) + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=-100, + pad_to_multiple_of=None, + ) + train_dataloader = data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + ) + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + start_time = time.time() + model.train() + for epoch in range(num_train_epochs): + total_loss = 0.0 + for batch in train_dataloader: + 
optimizer.zero_grad(set_to_none=True) + batch = send_to_device(batch, device=DEVICE) + loss = model(**batch).loss + loss.backward() + optimizer.step() + total_loss += loss.detach().float() + logging.info(f"Epoch {epoch + 1} - Averaged Loss: {total_loss / len(dataset)}") + end_time = time.time() + elapsed_time = end_time - start_time + logging.info(f"Completed training in {elapsed_time:.2f} seconds.") + return model + + +def evaluate_model(model: nn.Module, tokenizer: Any, dataset: data.Dataset, batch_size: int) -> Dict[str, Any]: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=-100, + pad_to_multiple_of=None, + ) + dataloader = data.DataLoader( + dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=data_collator + ) + model.eval() + + def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence. + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + gen_kwargs = { + "max_length": 128, + } + metric = evaluate.load("rouge") + loss_fn = CrossEntropyLoss(ignore_index=-100, reduction="mean") + total_loss = 0.0 + for step, batch in enumerate(dataloader): + with torch.no_grad(): + logits = model( + input_ids=batch["input_ids"].to(device=DEVICE), + attention_mask=batch["attention_mask"].to(device=DEVICE), + decoder_input_ids=batch["decoder_input_ids"].to(device=DEVICE), + ).logits + labels = batch["labels"].to(device=DEVICE) + loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1)) + total_loss += loss.detach().float().item() + + labels = labels.cpu().numpy() + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + generated_tokens = model.generate( + batch["input_ids"].to(device=DEVICE), + attention_mask=batch["attention_mask"].to(device=DEVICE), + **gen_kwargs, + ) + if isinstance(generated_tokens, tuple): + generated_tokens = generated_tokens[0] + decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) + metric.add_batch( + predictions=decoded_preds, + references=decoded_labels, + ) + + result = metric.compute(use_stemmer=True) + result = {k: round(v * 100, 4) for k, v in result.items()} + result["loss"] = total_loss / len(dataloader) + return result + + +def main(): + args = parse_args() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger() + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = get_tokenizer() + train_dataset = get_dailymail_dataset(split="train") + model = train( + dataset=train_dataset, + tokenizer=tokenizer, + batch_size=args.train_batch_size, + num_train_epochs=args.num_train_epochs, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + ) + + eval_train_dataset = get_dailymail_dataset(split="eval_train") + results = evaluate_model( + model=model, tokenizer=tokenizer, dataset=eval_train_dataset, batch_size=args.eval_batch_size + ) + logger.info(f"Train evaluation results: {results}") + + eval_dataset = get_dailymail_dataset(split="valid") + results = evaluate_model(model=model, tokenizer=tokenizer, dataset=eval_dataset, batch_size=args.eval_batch_size) + logger.info(f"Valid evaluation results: {results}") + + if 
args.checkpoint_dir is not None: + torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth")) + + +if __name__ == "__main__": + main() diff --git a/examples/glue/README.md b/examples/glue/README.md index 045eac4..c1b43a4 100644 --- a/examples/glue/README.md +++ b/examples/glue/README.md @@ -9,7 +9,7 @@ pip install -r requirements.txt ## Training -To fine-tune BERT on a specific dataset, run the following command (we are using the `SST2` dataset in this example): +To fine-tune BERT on a specific dataset, run the following command (we are using the SST2 dataset in this example): ```bash python train.py --dataset_name sst2 \ @@ -36,25 +36,25 @@ python analyze.py --dataset_name sst2 \ --factor_strategy ekfac ``` -On an A100 (80GB), it takes roughly 95 minutes to compute the pairwise scores for `SST2` (including computing EKFAC factors): +On an A100 (80GB), it takes roughly 90 minutes to compute the pairwise scores for SST2 (including computing EKFAC factors): ``` ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 5568.5 | 100 % | +| Total | - | 11 | 5088.0 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 2668.0 | 1 | 2668.0 | 47.913 | -| Fit Lambda | 2361.5 | 1 | 2361.5 | 42.408 | -| Fit Covariance | 483.63 | 1 | 483.63 | 8.685 | -| Perform Eigendecomposition | 26.307 | 1 | 26.307 | 0.47243 | -| Save Covariance | 11.445 | 1 | 11.445 | 0.20552 | -| Save Eigendecomposition | 10.959 | 1 | 10.959 | 0.1968 | -| Save Lambda | 3.0458 | 1 | 3.0458 | 0.054696 | -| Save Pairwise Score | 2.0978 | 1 | 2.0978 | 0.037671 | -| Load Covariance | 0.72168 | 1 | 0.72168 | 0.01296 | -| Load Eigendecomposition | 0.5194 | 1 | 0.5194 | 0.0093274 | -| Load All Factors | 0.25427 | 1 | 0.25427 | 0.0045661 | +| Fit Lambda | 2370.0 | 1 | 2370.0 | 46.581 | +| Compute Pairwise Score | 2222.4 | 1 | 2222.4 | 43.679 | +| Fit Covariance | 478.83 | 1 | 478.83 | 9.411 | +| Perform Eigendecomposition | 10.587 | 1 | 10.587 | 0.20808 | +| Save Eigendecomposition | 2.5419 | 1 | 2.5419 | 0.049958 | +| Save Covariance | 2.3878 | 1 | 2.3878 | 0.046931 | +| Save Lambda | 0.66905 | 1 | 0.66905 | 0.01315 | +| Save Pairwise Score | 0.51374 | 1 | 0.51374 | 0.010097 | +| Load All Factors | 0.01321 | 1 | 0.01321 | 0.00025963 | +| Load Covariance | 0.0081149 | 1 | 0.0081149 | 0.00015949 | +| Load Eigendecomposition | 0.0079874 | 1 | 0.0079874 | 0.00015699 | ---------------------------------------------------------------------------------------------------------------------------------- ``` @@ -69,32 +69,32 @@ python analyze.py --dataset_name sst2 \ --use_half_precision ``` -This reduces computation time to about 30 minutes on an A100 (80GB) GPU. +This reduces computation time to about 20 minutes on an A100 (80GB) GPU. 
``` ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 1832.0 | 100 % | +| Total | - | 11 | 1222.4 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 1143.2 | 1 | 1143.2 | 62.4 | -| Fit Lambda | 555.97 | 1 | 555.97 | 30.348 | -| Fit Covariance | 99.467 | 1 | 99.467 | 5.4294 | -| Perform Eigendecomposition | 18.566 | 1 | 18.566 | 1.0134 | -| Save Covariance | 5.4877 | 1 | 5.4877 | 0.29954 | -| Save Eigendecomposition | 5.3713 | 1 | 5.3713 | 0.29319 | -| Save Lambda | 1.5586 | 1 | 1.5586 | 0.085078 | -| Save Pairwise Score | 1.0651 | 1 | 1.0651 | 0.05814 | -| Load Eigendecomposition | 0.54052 | 1 | 0.54052 | 0.029504 | -| Load Covariance | 0.53759 | 1 | 0.53759 | 0.029345 | -| Load All Factors | 0.26048 | 1 | 0.26048 | 0.014218 | +| Compute Pairwise Score | 582.08 | 1 | 582.08 | 47.617 | +| Fit Lambda | 543.55 | 1 | 543.55 | 44.465 | +| Fit Covariance | 83.877 | 1 | 83.877 | 6.8616 | +| Perform Eigendecomposition | 9.4054 | 1 | 9.4054 | 0.76942 | +| Save Eigendecomposition | 1.516 | 1 | 1.516 | 0.12401 | +| Save Covariance | 1.434 | 1 | 1.434 | 0.11731 | +| Save Lambda | 0.28022 | 1 | 0.28022 | 0.022924 | +| Save Pairwise Score | 0.24123 | 1 | 0.24123 | 0.019734 | +| Load All Factors | 0.01241 | 1 | 0.01241 | 0.0010152 | +| Load Covariance | 0.0080553 | 1 | 0.0080553 | 0.00065897 | +| Load Eigendecomposition | 0.0077278 | 1 | 0.0077278 | 0.00063218 | ---------------------------------------------------------------------------------------------------------------------------------- ``` ## Counterfactual Evaluation -Evaluate the impact of removing top positively influential training examples on query misclassification. -First, compute pairwise influence scores for the `RTE` dataset: +Let's evaluate the impact of removing top positively influential training examples on query misclassification. 
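+
+As a sketch of the procedure: rank the training examples by their influence score on a misclassified query, remove the top-k most positively influential ones, retrain, and check whether the query becomes correctly classified. Once the scores below are saved, the selection step looks roughly like this (the path assumes `scores_name="ekfac"`; adjust it to your run):
+
+```python
+import torch
+
+from kronfluence.analyzer import Analyzer
+
+# Pairwise scores have shape [num_queries, num_train_examples].
+scores = Analyzer.load_file("influence_results/rte/scores_ekfac/pairwise_scores.safetensors")["all_modules"]
+
+query_idx = 0  # A misclassified validation example (hypothetical index).
+remove_indices = torch.argsort(scores[query_idx], descending=True)[:100].tolist()
+# Retrain with `remove_indices` excluded from the training set and re-evaluate the query.
+```
+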
+First, compute pairwise influence scores for the `RTE` dataset (the commands below use a single A100 GPU):

```bash
python train.py --dataset_name rte \
diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py
index 6876984..90dd71f 100644
--- a/examples/glue/analyze.py
+++ b/examples/glue/analyze.py
@@ -1,7 +1,7 @@
 import argparse
 import logging
 import os
-from typing import Dict, Optional
+from typing import Dict
 
 import torch
 import torch.nn.functional as F
@@ -126,7 +126,7 @@ def compute_measurement(
         margins = logits_correct - cloned_logits.logsumexp(dim=-1)
         return -margins.sum()
 
-    def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
+    def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
         return batch["attention_mask"]
 
 
@@ -188,8 +188,8 @@ def main():
         scores_name += "_half"
     rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
     if rank is not None:
-        score_args.query_gradient_rank = rank
-        score_args.num_query_gradient_accumulations = 10
+        score_args.query_gradient_low_rank = rank
+        score_args.query_gradient_accumulation_steps = 10
         scores_name += f"_qlr{rank}"
     analyzer.compute_pairwise_scores(
         score_args=score_args,
diff --git a/examples/glue/half_precision_analysis.py b/examples/glue/half_precision_analysis.py
new file mode 100644
index 0000000..288c39c
--- /dev/null
+++ b/examples/glue/half_precision_analysis.py
@@ -0,0 +1,40 @@
+import logging
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import spearmanr
+from tueplots import markers
+
+from kronfluence.analyzer import Analyzer
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    # Load the scores. You might need to modify the path.
+    scores = Analyzer.load_file("influence_results/sst2/scores_ekfac/pairwise_scores.safetensors")["all_modules"]
+    half_scores = Analyzer.load_file("influence_results/sst2/scores_ekfac_half/pairwise_scores.safetensors")[
+        "all_modules"
+    ].float()
+
+    plt.rcParams.update({"figure.dpi": 150})
+    plt.rcParams.update(markers.with_edge())
+    plt.rcParams["axes.axisbelow"] = True
+
+    # Only plot first 6000 points to avoid clutter.
+    idx = 79
+    plt.scatter(half_scores[idx][:6000], scores[idx][:6000], edgecolor="k")
+    plt.grid()
+    plt.xlabel("bfloat16")
+    plt.ylabel("float32")
+    plt.show()
+
+    # Compute the averaged Spearman correlation.
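+    # A rank correlation close to 1.0 means the bfloat16 scores preserve the
+    # float32 ranking of training examples for each query.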
+ all_corr = [] + for i in range(500): + all_corr.append(spearmanr(scores[i], half_scores[i])[0]) + logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}") + + +if __name__ == "__main__": + main() diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md index e959c19..6595bb5 100644 --- a/examples/imagenet/README.md +++ b/examples/imagenet/README.md @@ -17,7 +17,7 @@ To compute pairwise influence scores on 1000 query data points using the `ekfac` python analyze.py --dataset_dir PATH_TO_IMAGENET \ --query_gradient_rank -1 \ --query_batch_size 100 \ - --train_batch_size 300 \ + --train_batch_size 256 \ --factor_strategy ekfac ``` @@ -53,17 +53,17 @@ python analyze.py --dataset_dir PATH_TO_IMAGENET \ --factor_strategy ekfac ``` -On an A100 (80GB) GPU, it takes roughly 3.5 hours to compute the pairwise scores with query batching (including computing EKFAC factors): +On an A100 (80GB) GPU, it takes roughly 3 hours to compute the pairwise scores with query batching: ``` ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 3 | 9896.8 | 100 % | +| Total | - | 3 | 7352.8 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 9849.7 | 1 | 9849.7 | 99.524 | -| Save Pairwise Score | 47.075 | 1 | 47.075 | 0.47566 | -| Load All Factors | 0.014463 | 1 | 0.014463 | 0.00014614 | +| Compute Pairwise Score | 7340.8 | 1 | 7340.8 | 99.836 | +| Save Pairwise Score | 12.026 | 1 | 12.026 | 0.16355 | +| Load All Factors | 0.0099941 | 1 | 0.0099941 | 0.00013592 | ---------------------------------------------------------------------------------------------------------------------------------- ``` @@ -91,19 +91,19 @@ This reduces computation time to about 85 minutes on an A100 (80GB) GPU: ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 4926.3 | 100 % | ----------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 4446.0 | 1 | 4446.0 | 90.249 | -| Fit Lambda | 255.45 | 1 | 255.45 | 5.1853 | -| Fit Covariance | 186.86 | 1 | 186.86 | 3.7931 | -| Save Pairwise Score | 23.205 | 1 | 23.205 | 0.47104 | -| Perform Eigendecomposition | 7.1356 | 1 | 7.1356 | 0.14485 | -| Save Eigendecomposition | 3.3045 | 1 | 3.3045 | 0.067079 | -| Save Covariance | 2.993 | 1 | 2.993 | 0.060756 | -| Save Lambda | 0.58278 | 1 | 0.58278 | 0.01183 | -| Load Eigendecomposition | 0.39114 | 1 | 0.39114 | 0.0079398 | -| Load Covariance | 0.27701 | 1 | 0.27701 | 0.005623 | -| Load All Factors | 0.1699 | 1 | 0.1699 | 0.0034489 | +| Total | - | 11 | 3023.7 | 100 % | +---------------------------------------------------------------------------------------------------------------------------------- +| Compute Pairwise Score | 2621.2 | 1 | 2621.2 | 86.688 | +| Fit Lambda | 232.3 | 1 | 232.3 | 7.6825 | +| Fit Covariance 
| 157.36 | 1 | 157.36 | 5.204 | +| Perform Eigendecomposition | 5.745 | 1 | 5.745 | 0.19 | +| Save Pairwise Score | 5.6676 | 1 | 5.6676 | 0.18744 | +| Save Covariance | 0.70454 | 1 | 0.70454 | 0.0233 | +| Save Eigendecomposition | 0.61539 | 1 | 0.61539 | 0.020352 | +| Save Lambda | 0.092784 | 1 | 0.092784 | 0.0030685 | +| Load Covariance | 0.013714 | 1 | 0.013714 | 0.00045354 | +| Load All Factors | 0.0088742 | 1 | 0.0088742 | 0.00029348 | +| Load Eigendecomposition | 0.0056237 | 1 | 0.0056237 | 0.00018599 | ---------------------------------------------------------------------------------------------------------------------------------- ``` diff --git a/examples/imagenet/analyze.py b/examples/imagenet/analyze.py index d77de2e..9dd14b1 100644 --- a/examples/imagenet/analyze.py +++ b/examples/imagenet/analyze.py @@ -101,6 +101,7 @@ def main(): per_device_batch_size=None, factor_args=factor_args, overwrite_output_dir=False, + initial_per_device_batch_size_attempt=512, ) # Compute pairwise scores. @@ -111,8 +112,8 @@ def main(): scores_name += "_half" rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None if rank is not None: - score_args.query_gradient_rank = rank - score_args.num_query_gradient_accumulations = 10 + score_args.query_gradient_low_rank = rank + score_args.query_gradient_accumulation_steps = 10 scores_name += f"_qlr{rank}" analyzer.compute_pairwise_scores( scores_name=scores_name, diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py index f3c9733..33f9653 100644 --- a/examples/imagenet/ddp_analyze.py +++ b/examples/imagenet/ddp_analyze.py @@ -1,7 +1,6 @@ import argparse import logging import os -from typing import Tuple import torch @@ -133,8 +132,8 @@ def main(): scores_name += "_half" rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None if rank is not None: - score_args.query_gradient_rank = rank - score_args.num_query_gradient_accumulations = 10 + score_args.query_gradient_low_rank = rank + score_args.query_gradient_accumulation_steps = 10 scores_name += f"_qlr{rank}" analyzer.compute_pairwise_scores( scores_name=scores_name, diff --git a/examples/imagenet/query_batching_analysis.py b/examples/imagenet/query_batching_analysis.py index ae3e92b..6e82e98 100644 --- a/examples/imagenet/query_batching_analysis.py +++ b/examples/imagenet/query_batching_analysis.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import numpy as np -from scipy.stats import spearmanr +from scipy.stats import pearsonr, spearmanr from tueplots import markers from kronfluence.analyzer import Analyzer @@ -18,12 +18,15 @@ def main(): lr_scores = Analyzer.load_file("influence_results/imagenet/scores_ekfac_qlr32/pairwise_scores.safetensors")[ "all_modules" ] + # lr_scores = Analyzer.load_file("influence_results/imagenet/scores_ekfac_half_qlr32/pairwise_scores.safetensors")[ + # "all_modules" + # ].float() - # Only plot first 1000 points to avoid clutter. + # Only plot first 5000 points to avoid clutter. 
plt.rcParams.update({"figure.dpi": 150}) plt.rcParams.update(markers.with_edge()) plt.rcParams["axes.axisbelow"] = True - plt.scatter(lr_scores[0][:1000], full_scores[0][:1000], edgecolor="k") + plt.scatter(lr_scores[0][:5000], full_scores[0][:5000], edgecolor="k") plt.grid() plt.xlabel("Full Rank Score") plt.ylabel("Low Rank (32) Score") @@ -35,6 +38,12 @@ def main(): all_corr.append(spearmanr(full_scores[i], lr_scores[i])[0]) logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}") + # Compute the averaged pearson correlation. + all_corr = [] + for i in range(100): + all_corr.append(pearsonr(full_scores[i], lr_scores[i])[0]) + logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}") + if __name__ == "__main__": main() diff --git a/examples/openwebtext/README.md b/examples/openwebtext/README.md index aeca298..f10ece0 100644 --- a/examples/openwebtext/README.md +++ b/examples/openwebtext/README.md @@ -1,6 +1,29 @@ +# OpenWebText & Llama-3-8B Example + +This repository contains scripts for computing influence scores on the subset of OpenWebText dataset. The pipeline is motivated from [LoggIX repository](https://github.com/logix-project/logix/tree/main/examples/language_modeling). +Install the necessary packages: + ```bash -python analyze.py --query_batch_size 32 \ - --train_batch_size 64 \ - --checkpoint_dir ./checkpoints \ - --factor_strategy ekfac +pip install -r requirements.txt +``` + +We will use the pre-trained Meta-Llama-3-8B model [from HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B). + +## Computing EKFAC Factors + +To compute factors using the `ekfac` factorization strategy, run the following command which uses 4 A100 (80GB) GPUs: + +```bash +torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factors_name jul_11_2024 --factor_batch_size 4 +``` + +## Computing Influence Scores + +The `generate.py` folder contains a code to generate response of the Llama-3-8B model given certain prompt. +I saved some prompt and completition pair to the directory `data/data.json`. 
+
+To compute influence scores on the generated prompt and completion pairs, run the following command:
+
+```bash
+torchrun --standalone --nnodes=1 --nproc-per-node=4 compute_scores.py --train_batch_size 8 --query_gradient_rank 32
 ```
\ No newline at end of file
diff --git a/examples/openwebtext/analyze.py b/examples/openwebtext/analyze.py
deleted file mode 100644
index 230518c..0000000
--- a/examples/openwebtext/analyze.py
+++ /dev/null
@@ -1,217 +0,0 @@
-import argparse
-import logging
-import os
-from typing import Dict, List, Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from transformers import default_data_collator
-
-from examples.openwebtext.pipeline import (
-    construct_llama3,
-    get_custom_dataset,
-    get_openwebtext_dataset,
-)
-from examples.wikitext.pipeline import construct_gpt2, get_wikitext_dataset
-from kronfluence.analyzer import Analyzer, prepare_model
-from kronfluence.arguments import FactorArguments, ScoreArguments
-from kronfluence.task import Task
-from kronfluence.utils.common.factor_arguments import (
-    all_low_precision_factor_arguments,
-    extreme_reduce_memory_factor_arguments,
-)
-from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
-from kronfluence.utils.dataset import DataLoaderKwargs
-
-BATCH_TYPE = Dict[str, torch.Tensor]
-
-
-if torch.cuda.is_available():
-    torch.backends.cuda.matmul.allow_tf32 = True
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description="Influence analysis on WikiText dataset.")
-
-    parser.add_argument(
-        "--checkpoint_dir",
-        type=str,
-        default="./checkpoints",
-        help="A path that is storing the final checkpoint of the model.",
-    )
-
-    parser.add_argument(
-        "--factor_strategy",
-        type=str,
-        default="ekfac",
-        help="Strategy to compute influence factors.",
-    )
-    parser.add_argument(
-        "--query_gradient_rank",
-        type=int,
-        default=-1,
-        help="Rank for the low-rank query gradient approximation.",
-    )
-    parser.add_argument(
-        "--use_half_precision",
-        action="store_true",
-        default=False,
-        help="Whether to use half precision for computing factors and scores.",
-    )
-    parser.add_argument(
-        "--use_compile",
-        action="store_true",
-        default=False,
-        help="Whether to use torch compile for computing factors and scores.",
-    )
-    parser.add_argument(
-        "--query_batch_size",
-        type=int,
-        default=1,
-        help="Batch size for computing query gradients.",
-    )
-    parser.add_argument(
-        "--train_batch_size",
-        type=int,
-        default=8,
-        help="Batch size for computing query gradients.",
-    )
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        default=False,
-        help="Boolean flag to profile computations.",
-    )
-    args = parser.parse_args()
-
-    if args.checkpoint_dir is not None:
-        os.makedirs(args.checkpoint_dir, exist_ok=True)
-
-    return args
-
-
-class LanguageModelingTask(Task):
-    def compute_train_loss(
-        self,
-        batch: BATCH_TYPE,
-        model: nn.Module,
-        sample: bool = False,
-    ) -> torch.Tensor:
-        logits = model(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-        ).logits
-        shift_logits = logits[..., :-1, :].contiguous()
-
-        if not sample:
-            labels = batch["labels"]
-            shift_labels = labels[..., 1:].contiguous()
-            reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            summed_loss = F.cross_entropy(
-                reshaped_shift_logits, shift_labels.view(-1), reduction="sum", ignore_index=-100
-            )
-        else:
-            reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            with torch.no_grad():
-                probs = 
torch.nn.functional.softmax(reshaped_shift_logits.detach(), dim=-1) - sampled_labels = torch.multinomial( - probs, - num_samples=1, - ).flatten() - summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, reduction="sum") - return summed_loss - - def compute_measurement( - self, - batch: BATCH_TYPE, - model: nn.Module, - ) -> torch.Tensor: - logits = model( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - ).logits - shift_logits = logits[..., :-1, :].contiguous() - labels = batch["labels"] - shift_labels = labels[..., 1:].contiguous().view(-1) - reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - return F.cross_entropy(reshaped_shift_logits, shift_labels, reduction="sum", ignore_index=-100) - - def tracked_modules(self) -> List[str]: - total_modules = [] - - for i in range(32): - # Only uses the MLP modules. - total_modules.append(f"model.layers.{i}.mlp.gate_proj") - total_modules.append(f"model.layers.{i}.mlp.up_proj") - total_modules.append(f"model.layers.{i}.mlp.down_proj") - - return total_modules - - def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]: - return batch["attention_mask"] - - -def main(): - args = parse_args() - logging.basicConfig(level=logging.INFO) - - # Prepare the dataset. - train_dataset = get_openwebtext_dataset() - eval_dataset = get_custom_dataset() - - # Prepare the trained model. - model = construct_llama3().to(dtype=torch.bfloat16) - - # Define task and prepare model. - task = LanguageModelingTask() - model = prepare_model(model, task) - - analyzer = Analyzer( - analysis_name="openwebtext", - model=model, - task=task, - profile=args.profile, - ) - # Configure parameters for DataLoader. - dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator) - analyzer.set_dataloader_kwargs(dataloader_kwargs) - - # Compute influence factors. - factors_name = args.factor_strategy - factor_args = extreme_reduce_memory_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16) - analyzer.fit_all_factors( - factors_name=factors_name, - dataset=train_dataset, - per_device_batch_size=None, - factor_args=factor_args, - overwrite_output_dir=False, - initial_per_device_batch_size_attempt=64, - ) - - # Compute pairwise scores. 
-    scores_name = factor_args.strategy
-    score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
-
-    rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
-    score_args.num_query_gradient_accumulations = 10
-    if rank is not None:
-        score_args.query_gradient_rank = rank
-        scores_name += f"_qlr{rank}"
-    analyzer.compute_pairwise_scores(
-        scores_name=scores_name,
-        score_args=score_args,
-        factors_name=factors_name,
-        query_dataset=eval_dataset,
-        query_indices=list(range(min([len(eval_dataset), 2000]))),
-        train_dataset=train_dataset,
-        per_device_query_batch_size=args.query_batch_size,
-        per_device_train_batch_size=args.train_batch_size,
-        overwrite_output_dir=False,
-    )
-    scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
-    logging.info(f"Scores shape: {scores.shape}")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/openwebtext/compute_scores.py b/examples/openwebtext/compute_scores.py
new file mode 100644
index 0000000..61622d3
--- /dev/null
+++ b/examples/openwebtext/compute_scores.py
@@ -0,0 +1,116 @@
+import argparse
+import logging
+from datetime import timedelta
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator, InitProcessGroupKwargs
+from torch import nn
+from transformers import default_data_collator
+
+from examples.openwebtext.pipeline import (
+    construct_llama3,
+    get_custom_dataset,
+    get_openwebtext_dataset,
+)
+from examples.openwebtext.task import LanguageModelingTask
+from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.task import Task
+from kronfluence.utils.common.factor_arguments import (
+    extreme_reduce_memory_factor_arguments,
+)
+from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments, \
+    extreme_reduce_memory_score_arguments
+from kronfluence.utils.dataset import DataLoaderKwargs
+
+BATCH_TYPE = Dict[str, torch.Tensor]
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Influence score computation on OpenWebText dataset.")
+
+    parser.add_argument(
+        "--factor_strategy",
+        type=str,
+        default="ekfac",
+        help="Strategy to compute influence factors.",
+    )
+    parser.add_argument(
+        "--query_gradient_rank",
+        type=int,
+        default=-1,
+        help="Rank for the low-rank query gradient approximation.",
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=4,
+        help="Batch size for computing training gradients.",
+    )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Boolean flag to profile computations.",
+    )
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+    logging.basicConfig(level=logging.INFO)
+
+    # Prepare the dataset.
+    train_dataset = get_openwebtext_dataset()
+    eval_dataset = get_custom_dataset()
+
+    # Prepare the trained model.
+    model = construct_llama3()
+
+    # Define task and prepare model.
+    task = LanguageModelingTask()
+    model = prepare_model(model, task)
+
+    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))  # 1.5 hours.
+    accelerator = Accelerator(kwargs_handlers=[kwargs])
+    model = accelerator.prepare_model(model)
+
+    analyzer = Analyzer(
+        analysis_name="openwebtext",
+        model=model,
+        task=task,
+        profile=args.profile,
+    )
+    # Configure parameters for DataLoader.
+    dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
+    analyzer.set_dataloader_kwargs(dataloader_kwargs)
+
+    scores_name = args.factor_strategy
+    rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
+    score_args = extreme_reduce_memory_score_arguments(
+        damping_factor=None, module_partitions=1, query_gradient_low_rank=rank, dtype=torch.bfloat16
+    )
+    # score_args.module_partitions = 2
+    score_args.query_gradient_accumulation_steps = 10
+    analyzer.compute_pairwise_scores(
+        scores_name=scores_name,
+        score_args=score_args,
+        factors_name=args.factor_strategy,
+        query_dataset=eval_dataset,
+        train_dataset=train_dataset,
+        per_device_query_batch_size=1,
+        per_device_train_batch_size=args.train_batch_size,
+        overwrite_output_dir=False,
+    )
+    scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
+    logging.info(f"Scores shape: {scores.shape}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/openwebtext/fit_factors.py b/examples/openwebtext/fit_factors.py
new file mode 100644
index 0000000..740f349
--- /dev/null
+++ b/examples/openwebtext/fit_factors.py
@@ -0,0 +1,102 @@
+import argparse
+import logging
+from datetime import timedelta
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator, InitProcessGroupKwargs
+from torch import nn
+from transformers import default_data_collator
+
+from examples.openwebtext.pipeline import construct_llama3, get_openwebtext_dataset
+from examples.openwebtext.task import LanguageModelingTask
+from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.utils.common.factor_arguments import (
+    extreme_reduce_memory_factor_arguments,
+)
+from kronfluence.utils.dataset import DataLoaderKwargs
+
+BATCH_TYPE = Dict[str, torch.Tensor]
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Influence factor computation on OpenWebText dataset.")
+
+    parser.add_argument(
+        "--factors_name",
+        type=str,
+        default="july_11",
+        help="Name to save the computed influence factors.",
+    )
+    parser.add_argument(
+        "--factor_strategy",
+        type=str,
+        default="ekfac",
+        help="Strategy to compute influence factors.",
+    )
+    parser.add_argument(
+        "--factor_batch_size",
+        type=int,
+        default=4,
+        help="Batch size for computing influence factors.",
+    )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Boolean flag to profile computations.",
+    )
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+    logging.basicConfig(level=logging.INFO)
+
+    # Prepare the dataset.
+    train_dataset = get_openwebtext_dataset()
+
+    # Prepare the trained model.
+    model = construct_llama3()
+
+    # Define task and prepare model.
+    task = LanguageModelingTask()
+    model = prepare_model(model, task)
+
+    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))  # 1.5 hours.
+    accelerator = Accelerator(kwargs_handlers=[kwargs])
+    model = accelerator.prepare_model(model)
+
+    analyzer = Analyzer(
+        analysis_name="openwebtext",
+        model=model,
+        task=task,
+        profile=args.profile,
+    )
+    # Configure parameters for DataLoader.
+ dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True) + analyzer.set_dataloader_kwargs(dataloader_kwargs) + + factors_name = args.factors_name + factor_args = extreme_reduce_memory_factor_arguments( + strategy=args.factor_strategy, module_partitions=1, dtype=torch.bfloat16 + ) + factor_args.covariance_module_partitions = 2 + factor_args.lambda_module_partitions = 4 + analyzer.fit_all_factors( + factors_name=factors_name, + dataset=train_dataset, + per_device_batch_size=args.factor_batch_size, + factor_args=factor_args, + overwrite_output_dir=False, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/openwebtext/pipeline.py b/examples/openwebtext/pipeline.py index 1c4589a..fda0fe0 100644 --- a/examples/openwebtext/pipeline.py +++ b/examples/openwebtext/pipeline.py @@ -1,19 +1,22 @@ import copy from typing import List +import torch from datasets import load_dataset from torch import nn from torch.utils import data from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +MODEL_NAME = "meta-llama/Meta-Llama-3-8B" +MAX_LENGTH = 512 + def construct_llama3() -> nn.Module: config = AutoConfig.from_pretrained( - "meta-llama/Meta-Llama-3-8B", - trust_remote_code=True, + MODEL_NAME, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto" ) model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Meta-Llama-3-8B", + MODEL_NAME, from_tf=False, config=config, ignore_mismatched_sizes=False, @@ -25,17 +28,19 @@ def construct_llama3() -> nn.Module: def get_openwebtext_dataset( indices: List[int] = None, ) -> data.Dataset: - raw_datasets = load_dataset("stas/openwebtext-10k") - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True, trust_remote_code=True) - + raw_datasets = load_dataset("Elriggs/openwebtext-100k") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token column_names = raw_datasets["train"].column_names text_column_name = "text" if "text" in column_names else column_names[0] def tokenize_function(examples): - results = tokenizer(examples[text_column_name], truncation=True, padding=True, max_length=64) + results = tokenizer(examples[text_column_name], truncation=True, padding=True, max_length=MAX_LENGTH) results["labels"] = results["input_ids"].copy() + results["labels"] = [ + [-100 if token == tokenizer.pad_token_id else token for token in label] for label in results["labels"] + ] return results tokenized_datasets = raw_datasets.map( @@ -47,8 +52,7 @@ def tokenize_function(examples): desc="Running tokenizer on dataset", ) - train_dataset = tokenized_datasets["train"] - ds = train_dataset + ds = tokenized_datasets["train"] if indices is not None: ds = ds.select(indices) @@ -65,7 +69,7 @@ def get_custom_dataset( "num_proc": 4, } raw_datasets = load_dataset(**data_kwargs)["train"] - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True) def tokenize_function(examples): data_dict = {} diff --git a/examples/openwebtext/requirements.txt b/examples/openwebtext/requirements.txt new file mode 100644 index 0000000..462542f --- /dev/null +++ b/examples/openwebtext/requirements.txt @@ -0,0 +1,2 @@ +transformers +datasets diff --git a/examples/openwebtext/task.py b/examples/openwebtext/task.py new file mode 100644 index 0000000..343dba2 --- 
/dev/null +++ b/examples/openwebtext/task.py @@ -0,0 +1,71 @@ +from typing import Dict, List + +import torch +import torch.nn.functional as F +from torch import nn + +from kronfluence.task import Task + +BATCH_TYPE = Dict[str, torch.Tensor] + + +class LanguageModelingTask(Task): + def compute_train_loss( + self, + batch: BATCH_TYPE, + model: nn.Module, + sample: bool = False, + ) -> torch.Tensor: + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ).logits + logits = logits[..., :-1, :].contiguous() + logits = logits.view(-1, logits.size(-1)) + + if not sample: + labels = batch["labels"] + shift_labels = labels[..., 1:].contiguous() + summed_loss = F.cross_entropy(logits, shift_labels.view(-1), reduction="sum", ignore_index=-100) + else: + with torch.no_grad(): + probs = torch.nn.functional.softmax(logits.detach(), dim=-1) + sampled_labels = torch.multinomial( + probs, + num_samples=1, + ).flatten() + summed_loss = F.cross_entropy(logits, sampled_labels, ignore_index=-100, reduction="sum") + return summed_loss + + def compute_measurement( + self, + batch: BATCH_TYPE, + model: nn.Module, + ) -> torch.Tensor: + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ).logits + shift_labels = batch["labels"][..., 1:].contiguous().view(-1) + logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) + return F.cross_entropy(logits, shift_labels, ignore_index=-100, reduction="sum") + + def get_influence_tracked_modules(self) -> List[str]: + total_modules = [] + + # You can uncomment the following lines if you would like to compute influence also on attention layers. + # for i in range(32): + # total_modules.append(f"model.layers.{i}.self_attn.q_proj") + # total_modules.append(f"model.layers.{i}.self_attn.k_proj") + # total_modules.append(f"model.layers.{i}.self_attn.v_proj") + # total_modules.append(f"model.layers.{i}.self_attn.o_proj") + + for i in range(32): + total_modules.append(f"model.layers.{i}.mlp.gate_proj") + total_modules.append(f"model.layers.{i}.mlp.up_proj") + total_modules.append(f"model.layers.{i}.mlp.down_proj") + + return total_modules + + def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: + return batch["attention_mask"] diff --git a/examples/requirements.txt b/examples/requirements.txt index d7e148a..5fb9eed 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -5,3 +5,7 @@ tueplots transformers evaluate datasets +sentencepiece!=0.1.92 +nltk +py7zr +rouge-score \ No newline at end of file diff --git a/examples/swag/README.md b/examples/swag/README.md index 3863330..4ff8146 100644 --- a/examples/swag/README.md +++ b/examples/swag/README.md @@ -1,7 +1,7 @@ # SWAG & RoBERTa Example -This directory contains scripts for fine-tuning RoBERTa computing influence scores on the SWAG dataset. The pipeline is motivated from [this HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice). -To get started, please install the necessary packages: +This directory demonstrates fine-tuning RoBERTa on the SWAG dataset and computing influence scores. The implementation is inspired by [this HuggingFace example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice) and showcases how to define `post_process_per_sample_gradient`. 
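+
+For context, the SWAG data collator flattens each example's four answer choices into the batch dimension, so per-sample gradients must be summed back per example before influence computation. The hook defined in `analyze.py` later in this diff reduces to the following sketch:
+
+```python
+import torch
+
+def post_process_per_sample_gradient(self, module_name: str, gradient: torch.Tensor) -> torch.Tensor:
+    del module_name
+    # Four choices per example were flattened into the batch dimension; regroup and sum.
+    true_batch_size = gradient.size(0) // 4
+    return gradient.reshape(true_batch_size, 4, *gradient.size()[1:]).sum(dim=1)
+```
+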
+Install the required packages: ```bash pip install -r requirements.txt @@ -9,7 +9,7 @@ pip install -r requirements.txt ## Training -To fine-tune RoBERTa on SWAG, run the following command: +To fine-tune RoBERTa on the SWAG dataset, run the following command: ```bash python train.py --checkpoint_dir ./checkpoints \ @@ -21,61 +21,120 @@ python train.py --checkpoint_dir ./checkpoints \ --seed 1004 ``` -This will fine-tune the model using the specified hyperparameters and save the final checkpoint in the `./checkpoints` directory. +The final checkpoint will be saved in the `./checkpoints` directory. ## Computing Pairwise Influence Scores -To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command: +To calculate pairwise influence scores on 2000 query data points using `ekfac`, run: ```bash -python analyze.py --query_batch_size 64 \ +python analyze.py --factor_batch_size 128 \ + --query_batch_size 100 \ + --train_batch_size 64 \ + --checkpoint_dir ./checkpoints \ + --factor_strategy ekfac +``` + +Alternative options for `factor_strategy` include `identity`, `diagonal`, and `kfac`. +On an A100 (80GB), computing the pairwise scores (including EKFAC factors) takes approximately 10 hours: + +``` +---------------------------------------------------------------------------------------------------------------------------------- +| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | +---------------------------------------------------------------------------------------------------------------------------------- +| Total | - | 11 | 3.5124e+04 | 100 % | +---------------------------------------------------------------------------------------------------------------------------------- +| Compute Pairwise Score | 2.9384e+04 | 1 | 2.9384e+04 | 83.656 | +| Fit Lambda | 3578.7 | 1 | 3578.7 | 10.189 | +| Fit Covariance | 2143.9 | 1 | 2143.9 | 6.1036 | +| Perform Eigendecomposition | 10.213 | 1 | 10.213 | 0.029078 | +| Save Eigendecomposition | 3.4398 | 1 | 3.4398 | 0.0097933 | +| Save Covariance | 2.5179 | 1 | 2.5179 | 0.0071684 | +| Save Pairwise Score | 1.2982 | 1 | 1.2982 | 0.0036959 | +| Save Lambda | 0.68226 | 1 | 0.68226 | 0.0019424 | +| Load All Factors | 0.013627 | 1 | 0.013627 | 3.8797e-05 | +| Load Eigendecomposition | 0.0088496 | 1 | 0.0088496 | 2.5195e-05 | +| Load Covariance | 0.008222 | 1 | 0.008222 | 2.3408e-05 | +---------------------------------------------------------------------------------------------------------------------------------- +``` + +For more efficient computation, use half-precision: + +```bash +python analyze.py --factor_batch_size 128 \ + --query_batch_size 100 \ --train_batch_size 128 \ --use_half_precision \ --checkpoint_dir ./checkpoints \ --factor_strategy ekfac ``` -You can also use `identity`, `diagonal`, and `kfac` for `factor_strategy`. On an A6000 (48GB), it takes roughly 95 minutes to compute the pairwise scores (including computing EKFAC factors): +This reduces computation time to about 3 hours on an A100 (80GB) GPU. 
```
-
+----------------------------------------------------------------------------------------------------------------------------------
+| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Total | - | 11 | 1.0935e+04 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 9576.4 | 1 | 9576.4 | 87.578 |
+| Fit Lambda | 932.07 | 1 | 932.07 | 8.524 |
+| Fit Covariance | 411.81 | 1 | 411.81 | 3.7661 |
+| Perform Eigendecomposition | 10.623 | 1 | 10.623 | 0.097145 |
+| Save Eigendecomposition | 1.4735 | 1 | 1.4735 | 0.013475 |
+| Save Covariance | 1.2953 | 1 | 1.2953 | 0.011846 |
+| Save Pairwise Score | 0.66271 | 1 | 0.66271 | 0.0060606 |
+| Save Lambda | 0.34022 | 1 | 0.34022 | 0.0031114 |
+| Load All Factors | 0.012041 | 1 | 0.012041 | 0.00011012 |
+| Load Covariance | 0.0079526 | 1 | 0.0079526 | 7.2728e-05 |
+| Load Eigendecomposition | 0.0076841 | 1 | 0.0076841 | 7.0273e-05 |
+----------------------------------------------------------------------------------------------------------------------------------
 ```
 
-For more efficient computation, use AMP half precision + query batching + DDP:
+Query batching (low-rank approximation to the query gradient; see **Section 3.2.2** from the paper) can be used to compute influence scores with a larger query batch size:
 
 ```bash
-torchrun --standalone --nnodes=1 --nproc-per-node=2 analyze.py --factor_batch_size 128 \
+python analyze.py --factor_batch_size 128 \
     --query_batch_size 100 \
     --train_batch_size 128 \
-    --checkpoint_dir ./checkpoints \
-    --factor_strategy ekfac \
     --query_gradient_rank 32 \
     --use_half_precision \
-    --use_ddp
+    --checkpoint_dir ./checkpoints \
+    --factor_strategy ekfac
 ```
 
-This reduces computation time to about 20 minutes on an A100 (80GB) GPU:
+On an A100 (80GB) GPU, it takes roughly 1 hour to compute the pairwise scores with query batching:
 
 ```
-
+----------------------------------------------------------------------------------------------------------------------------------
+| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Total | - | 3 | 2007.9 | 100 % |
+----------------------------------------------------------------------------------------------------------------------------------
+| Compute Pairwise Score | 2007.2 | 1 | 2007.2 | 99.966 |
+| Save Pairwise Score | 0.66464 | 1 | 0.66464 | 0.033102 |
+| Load All Factors | 0.012345 | 1 | 0.012345 | 0.00061484 |
+----------------------------------------------------------------------------------------------------------------------------------
 ```
 
 ## Evaluating Linear Datamodeling Score
 
 The `evaluate_lds.py` script computes the [linear datamodeling score (LDS)](https://arxiv.org/abs/2303.14186). It measures the LDS obtained by retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks).
-We obtain `xx` LDS (we get `xx` LDS with the AMP half precision + query batching + DDP).
-
-The script can also print top influential sequences for a given query.
+We obtain `0.33` LDS (and `0.30` LDS both with half precision and with half precision + query batching).
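+
+Here, the LDS is the Spearman correlation between the measured outcome (margin) of each retrained model and the outcome predicted from the influence scores. A condensed sketch of what `evaluate_lds.py` computes, with assumed shapes `masks: [num_subsets, num_train]`, `margins: [num_subsets, num_queries]`, and `scores: [num_queries, num_train]`:
+
+```python
+import numpy as np
+from scipy.stats import spearmanr
+
+preds = masks @ scores.T  # Predicted margins for every (subset, query) pair.
+lds = np.mean([spearmanr(preds[:, i], margins[:, i])[0] for i in range(preds.shape[1])])
+```
+
+The script also prints the top influential training example for a given query: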
``` -Query Example: - Sentence1: The west has preferred to focus on endangered animals, rather than endangered humans. African elephants are hunted down and stripped of tusks and hidden by poachers. Their numbers in Africa slumped from 1.2m to 600,000 in a decade until CITES - the Convention on International Trade in Endangered Species - banned the trade in ivory. - Sentence2: African elephants are endangered by ivory poachers. - Label: 0 - +Query Data Example: + Option 0: He looks disgusted and spits it out onto the plate.He slides both hands around the crack. + Option 1: He looks disgusted and spits it out onto the plate.He passes someone to the bald guy. + Option 2: He looks disgusted and spits it out onto the plate.He picks up a piece of bread. + Option 3: He looks disgusted and spits it out onto the plate.He walks into the kitchen. + Label: 3 + Top Influential Example: - Sentence1: The article also mentions the greater prevalence of obesity among two minority populations, African-Americans and Hispanic/Latino, but does not consider in its analysis of the increase in obesity the increase of these these populations as a proportion of the United States population. African-Americans and Hispanic/Latinos have a higher rates of obesity than White Americans, while Asian-Americans have a relatively low rate of obesity. Despite only representing one third of the U.S. population, African-Americans and Hispanic/Latinos represent about one half of the population growth. - Sentence2: African-Americans are a minority in the U.S. - Label: 0 -``` + Option 0: He lowers her hair back over the cut.He lies fully clothed, still gazing at her scooter. + Option 1: He lowers her hair back over the cut.He bangs her head against her headrest. + Option 2: He lowers her hair back over the cut.He goes to the kitchen. + Option 3: He lowers her hair back over the cut.He gives him a sidelong look. + Label: 2 +``` \ No newline at end of file diff --git a/examples/swag/analyze.py b/examples/swag/analyze.py index 4db9a48..d5fb4c5 100644 --- a/examples/swag/analyze.py +++ b/examples/swag/analyze.py @@ -1,7 +1,7 @@ import argparse import logging import os -from typing import Dict, Optional +from typing import Dict import torch import torch.nn.functional as F @@ -93,6 +93,8 @@ def parse_args(): class MultipleChoiceTask(Task): + enable_post_process_per_sample_gradient = True + def compute_train_loss( self, batch: BATCH_TYPE, @@ -135,9 +137,15 @@ def compute_measurement( margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() - def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]: + def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: return batch["attention_mask"] + def post_process_per_sample_gradient(self, module_name: str, gradient: torch.Tensor) -> torch.Tensor: + del module_name + total_batch_size = gradient.size(0) + true_batch_size = int(total_batch_size / 4) + return gradient.reshape(true_batch_size, 4, *gradient.size()[1:]).sum(dim=1) + def main(): args = parse_args() @@ -169,7 +177,6 @@ def main(): rank=WORLD_RANK, world_size=WORLD_SIZE, ) - print(model) analyzer = Analyzer( analysis_name="swag", @@ -184,7 +191,6 @@ def main(): # Compute influence factors. 
factors_name = args.factor_strategy factor_args = FactorArguments(strategy=args.factor_strategy) - # factor_args.lambda_iterative_aggregate = True if args.use_half_precision: factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16) factors_name += "_half" @@ -206,8 +212,8 @@ def main(): scores_name += "_half" rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None if rank is not None: - score_args.query_gradient_rank = rank - score_args.num_query_gradient_accumulations = 10 + score_args.query_gradient_low_rank = rank + score_args.query_gradient_accumulation_steps = 10 scores_name += f"_qlr{rank}" if args.use_ddp: scores_name += "_ddp" diff --git a/examples/swag/evaluate_lds.py b/examples/swag/evaluate_lds.py index d493307..6bb0196 100644 --- a/examples/swag/evaluate_lds.py +++ b/examples/swag/evaluate_lds.py @@ -4,8 +4,9 @@ import torch import tqdm from scipy.stats import spearmanr +from transformers import AutoTokenizer -from examples.glue.pipeline import get_glue_dataset +from examples.swag.pipeline import get_swag_dataset from kronfluence.analyzer import Analyzer @@ -13,7 +14,7 @@ def evaluate_correlations(scores: torch.Tensor) -> float: margins = torch.from_numpy(torch.load(open("files/margins.pt", "rb"))) masks = torch.from_numpy(torch.load(open("files/masks.pt", "rb"))).float() - val_indices = np.arange(277) + val_indices = np.arange(2000) preds = masks @ scores.T rs = [] @@ -39,25 +40,28 @@ def main(): logging.info(f"LDS: {np.mean(corr_mean)}") # We can also visualize the top influential sequences. - eval_idx = 79 - train_dataset = get_glue_dataset( - data_name="rte", + eval_idx = 1004 + tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", use_fast=True, trust_remote_code=True) + + train_dataset = get_swag_dataset( split="eval_train", ) - eval_dataset = get_glue_dataset( - data_name="rte", + eval_dataset = get_swag_dataset( split="valid", ) + print("Query Data Example:") - print(f"Sentence1: {eval_dataset[eval_idx]['sentence1']}") - print(f"Sentence2: {eval_dataset[eval_idx]['sentence2']}") - print(f"Label: {eval_dataset[eval_idx]['label']}") + for i in range(4): + text = tokenizer.decode(eval_dataset[eval_idx]["input_ids"][i]) + print(f"Option {i}: {text}") + print(f"Label: {eval_dataset[eval_idx]['labels']}") top_idx = int(torch.argsort(scores[eval_idx], descending=True)[0]) print("Top Influential Example:") - print(f"Sentence1: {train_dataset[top_idx]['sentence1']}") - print(f"Sentence2: {train_dataset[top_idx]['sentence2']}") - print(f"Label: {train_dataset[top_idx]['label']}") + for i in range(4): + text = tokenizer.decode(train_dataset[top_idx]["input_ids"][i]) + print(f"Option {i}: {text}") + print(f"Label: {train_dataset[top_idx]['labels']}") if __name__ == "__main__": diff --git a/examples/swag/files/margins.pt b/examples/swag/files/margins.pt new file mode 100644 index 0000000..1e815d5 Binary files /dev/null and b/examples/swag/files/margins.pt differ diff --git a/examples/swag/files/masks.pt b/examples/swag/files/masks.pt new file mode 100644 index 0000000..653a6eb Binary files /dev/null and b/examples/swag/files/masks.pt differ diff --git a/examples/swag/influence_analysis.py b/examples/swag/influence_analysis.py new file mode 100644 index 0000000..d33a79a --- /dev/null +++ b/examples/swag/influence_analysis.py @@ -0,0 +1,41 @@ +import logging + +import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import spearmanr +from tueplots import markers + +from kronfluence.analyzer 
import Analyzer
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    scores1 = Analyzer.load_file("influence_results/swag/scores_ekfac/pairwise_scores.safetensors")[
+        "all_modules"
+    ].float()
+    scores2 = Analyzer.load_file("influence_results/swag/scores_ekfac_half_qlr32/pairwise_scores.safetensors")[
+        "all_modules"
+    ].float()
+
+    plt.rcParams.update({"figure.dpi": 150})
+    plt.rcParams.update(markers.with_edge())
+    plt.rcParams["axes.axisbelow"] = True
+
+    # Plot the scores for a single query data point.
+    idx = 1
+    plt.scatter(scores1[idx], scores2[idx], edgecolor="k")
+    plt.grid()
+    plt.xlabel("ekfac")
+    plt.ylabel("ekfac_half_qlr32")
+    plt.show()
+
+    # Compute the averaged Spearman correlation.
+    all_corr = []
+    for i in range(500):
+        all_corr.append(spearmanr(scores1[i], scores2[i])[0])
+    logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/swag/pipeline.py b/examples/swag/pipeline.py
index 2130c22..aaa7225 100644
--- a/examples/swag/pipeline.py
+++ b/examples/swag/pipeline.py
@@ -1,10 +1,18 @@
+from dataclasses import dataclass
 from itertools import chain
-from typing import Any, List
+from typing import Any, List, Optional, Union
 
+import torch
 from datasets import load_dataset
 from torch import nn
 from torch.utils.data import Dataset
-from transformers import AutoConfig, AutoModelForMultipleChoice, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForMultipleChoice,
+    AutoTokenizer,
+    PreTrainedTokenizerBase,
+)
+from transformers.utils import PaddingStrategy
 
 
 def construct_roberta() -> nn.Module:
@@ -19,6 +27,38 @@ def construct_roberta() -> nn.Module:
     )
 
 
+@dataclass
+class DataCollatorForMultipleChoice:
+    tokenizer: PreTrainedTokenizerBase
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+
+    def __call__(self, features):
+        label_name = "label" if "label" in features[0].keys() else "labels"
+        labels = [feature.pop(label_name) for feature in features]
+        batch_size = len(features)
+        num_choices = len(features[0]["input_ids"])
+        flattened_features = [
+            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
+        ]
+        flattened_features = list(chain(*flattened_features))
+
+        batch = self.tokenizer.pad(
+            flattened_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors="pt",
+        )
+
+        # Un-flatten
+        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
+        # Add back labels
+        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
+        return batch
+
+
 def get_swag_dataset(
     split: str,
     indices: List[int] = None,
@@ -67,7 +107,7 @@ def preprocess_function(examples: Any):
         remove_columns=raw_datasets["train"].column_names,
     )
 
-    if split == "train" or split == "eval_train":
+    if split in ["train", "eval_train"]:
         dataset = processed_datasets["train"]
         dataset = dataset.select(list(range(73_536)))
     else:
diff --git a/examples/uci/README.md b/examples/uci/README.md
index 735863b..91e1f66 100644
--- a/examples/uci/README.md
+++ b/examples/uci/README.md
@@ -18,7 +18,7 @@ python train.py --dataset_name concrete \
     --train_batch_size 32 \
     --eval_batch_size 1024 \
     --learning_rate 0.03 \
-    --weight_decay 1e-5 \
+    --weight_decay 1e-05 \
     --num_train_epochs 20 \
     --seed 1004
 ```
@@ -27,7 +27,7 @@ This will train the model using the specified hyperparameters and save the train
 
 ## Computing Pairwise Influence 
Scores -To compute pairwise influence scores using the `ekfac` factorization strategy, run the following command: +To compute pairwise influence scores using the `ekfac` strategy, run the following command: ```bash python analyze.py --dataset_name concrete \ @@ -36,26 +36,25 @@ python analyze.py --dataset_name concrete \ --factor_strategy ekfac ``` -You can also use `identity`, `diagonal`, and `kfac` for `factor_strategy`. -To measure the wall-clock time of computing influence scores, you can enable the `profile` flag: +You can also use `identity`, `diagonal`, and `kfac` for `factor_strategy`. To measure the wall-clock time of computing influence scores, you can enable the `profile` flag: ``` ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 0.35452 | 100 % | +| Total | - | 11 | 0.28983 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 0.13146 | 1 | 0.13146 | 37.082 | -| Fit Lambda | 0.12437 | 1 | 0.12437 | 35.08 | -| Fit Covariance | 0.077605 | 1 | 0.077605 | 21.89 | -| Perform Eigendecomposition | 0.0066845 | 1 | 0.0066845 | 1.8855 | -| Save Covariance | 0.0056978 | 1 | 0.0056978 | 1.6072 | -| Save Eigendecomposition | 0.0047404 | 1 | 0.0047404 | 1.3371 | -| Load Covariance | 0.0012774 | 1 | 0.0012774 | 0.36031 | -| Save Pairwise Score | 0.00080004 | 1 | 0.00080004 | 0.22567 | -| Save Lambda | 0.00074937 | 1 | 0.00074937 | 0.21138 | -| Load All Factors | 0.00066267 | 1 | 0.00066267 | 0.18692 | -| Load Eigendecomposition | 0.00047504 | 1 | 0.00047504 | 0.13399 | +| Compute Pairwise Score | 0.15112 | 1 | 0.15112 | 52.141 | +| Fit Lambda | 0.094882 | 1 | 0.094882 | 32.737 | +| Fit Covariance | 0.031336 | 1 | 0.031336 | 10.812 | +| Perform Eigendecomposition | 0.006672 | 1 | 0.006672 | 2.302 | +| Save Eigendecomposition | 0.0013218 | 1 | 0.0013218 | 0.45607 | +| Save Covariance | 0.0013099 | 1 | 0.0013099 | 0.45196 | +| Load All Factors | 0.00073975 | 1 | 0.00073975 | 0.25523 | +| Save Lambda | 0.00073158 | 1 | 0.00073158 | 0.25242 | +| Load Covariance | 0.00062487 | 1 | 0.00062487 | 0.2156 | +| Save Pairwise Score | 0.00058254 | 1 | 0.00058254 | 0.20099 | +| Load Eigendecomposition | 0.00050846 | 1 | 0.00050846 | 0.17543 | ---------------------------------------------------------------------------------------------------------------------------------- ``` diff --git a/examples/uci/tutorial.ipynb b/examples/uci/tutorial.ipynb index da30342..7160ce8 100644 --- a/examples/uci/tutorial.ipynb +++ b/examples/uci/tutorial.ipynb @@ -216,46 +216,46 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 213.03batch/s, loss=0.913]\n", - "Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 455.38batch/s, loss=0.629]\n", - "Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 591.37batch/s, 
loss=0.429]\n", - "Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 540.31batch/s, loss=0.352]\n", - "Epoch 4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 514.58batch/s, loss=0.29]\n", - "Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 502.35batch/s, loss=0.253]\n", - "Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 530.39batch/s, loss=0.222]\n", - "Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 542.57batch/s, loss=0.203]\n", - "Epoch 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 260.40batch/s, loss=0.2]\n", - "Epoch 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 474.28batch/s, loss=0.18]\n", - "Epoch 10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 544.02batch/s, loss=0.16]\n", - "Epoch 11: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 356.42batch/s, loss=0.157]\n", - "Epoch 12: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 565.94batch/s, loss=0.15]\n", - "Epoch 13: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 678.62batch/s, loss=0.132]\n", - "Epoch 14: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 605.49batch/s, loss=0.161]\n", - "Epoch 15: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 611.18batch/s, loss=0.14]\n", - "Epoch 16: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 576.23batch/s, loss=0.121]\n", - "Epoch 17: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 679.23batch/s, loss=0.128]\n", - "Epoch 18: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 650.96batch/s, loss=0.13]\n", - "Epoch 19: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 583.58batch/s, loss=0.116]\n", - "Epoch 20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 
[00:00<00:00, 630.26batch/s, loss=0.11]\n", - "Epoch 21: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 658.96batch/s, loss=0.119]\n", - "Epoch 22: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 608.30batch/s, loss=0.112]\n", - "Epoch 23: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 610.26batch/s, loss=0.1]\n", - "Epoch 24: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 576.88batch/s, loss=0.119]\n", - "Epoch 25: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 547.62batch/s, loss=0.0916]\n", - "Epoch 26: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 240.92batch/s, loss=0.0929]\n", - "Epoch 27: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 645.42batch/s, loss=0.0956]\n", - "Epoch 28: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 624.61batch/s, loss=0.0995]\n", - "Epoch 29: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 692.54batch/s, loss=0.093]\n", - "Epoch 30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 750.97batch/s, loss=0.0958]\n", - "Epoch 31: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 712.11batch/s, loss=0.0921]\n", - "Epoch 32: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 650.56batch/s, loss=0.0898]\n", - "Epoch 33: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 629.30batch/s, loss=0.0867]\n", - "Epoch 34: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 577.89batch/s, loss=0.0854]\n", - "Epoch 35: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 410.50batch/s, loss=0.0799]\n", - "Epoch 36: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 522.48batch/s, loss=0.0821]\n", - "Epoch 37: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 309.45batch/s, loss=0.0751]\n", - "Epoch 38: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 242.70batch/s, loss=0.079]\n", - "Epoch 39: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 546.25batch/s, loss=0.0838]\n" + "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 389.33batch/s, loss=0.913]\n", + "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 336.95batch/s, loss=0.629]\n", + "Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 538.93batch/s, loss=0.429]\n", + "Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 376.75batch/s, loss=0.352]\n", + "Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 484.09batch/s, loss=0.29]\n", + "Epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 545.75batch/s, loss=0.253]\n", + "Epoch 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 363.86batch/s, loss=0.222]\n", + "Epoch 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 528.11batch/s, loss=0.203]\n", + "Epoch 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 578.23batch/s, loss=0.2]\n", + "Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 579.47batch/s, loss=0.18]\n", + "Epoch 10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 575.52batch/s, loss=0.16]\n", + "Epoch 11: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 580.09batch/s, loss=0.157]\n", + "Epoch 12: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 560.29batch/s, loss=0.15]\n", + "Epoch 13: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 371.33batch/s, loss=0.132]\n", + "Epoch 14: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 555.09batch/s, loss=0.161]\n", + "Epoch 15: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 524.96batch/s, loss=0.14]\n", + "Epoch 16: 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 559.10batch/s, loss=0.121]\n", + "Epoch 17: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 549.66batch/s, loss=0.128]\n", + "Epoch 18: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 534.94batch/s, loss=0.13]\n", + "Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 360.81batch/s, loss=0.116]\n", + "Epoch 20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 583.09batch/s, loss=0.11]\n", + "Epoch 21: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 563.09batch/s, loss=0.119]\n", + "Epoch 22: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 612.03batch/s, loss=0.112]\n", + "Epoch 23: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 561.09batch/s, loss=0.1]\n", + "Epoch 24: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 396.01batch/s, loss=0.119]\n", + "Epoch 25: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 583.49batch/s, loss=0.0916]\n", + "Epoch 26: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 589.11batch/s, loss=0.0929]\n", + "Epoch 27: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 547.42batch/s, loss=0.0956]\n", + "Epoch 28: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 353.97batch/s, loss=0.0995]\n", + "Epoch 29: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 381.87batch/s, loss=0.093]\n", + "Epoch 30: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 509.93batch/s, loss=0.0958]\n", + "Epoch 31: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 543.89batch/s, loss=0.0921]\n", + "Epoch 32: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 609.94batch/s, loss=0.0898]\n", + "Epoch 33: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 377.88batch/s, loss=0.0867]\n", + "Epoch 34: 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 372.61batch/s, loss=0.0854]\n", + "Epoch 35: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 352.03batch/s, loss=0.0799]\n", + "Epoch 36: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 336.11batch/s, loss=0.0821]\n", + "Epoch 37: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 437.86batch/s, loss=0.0751]\n", + "Epoch 38: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 449.35batch/s, loss=0.079]\n", + "Epoch 39: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 455.28batch/s, loss=0.0838]\n" ] } ], @@ -646,7 +646,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 16, @@ -863,7 +863,7 @@ "source": [ "from kronfluence import ScoreArguments\n", "\n", - "score_args = ScoreArguments(per_module_score=True)\n", + "score_args = ScoreArguments(compute_per_module_scores=True)\n", "analyzer.compute_pairwise_scores(\n", " score_args=score_args,\n", " scores_name=\"tutorial_per_module_score\",\n", @@ -924,7 +924,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 24, @@ -1215,7 +1215,7 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAJqCAYAAAA/sUHAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAABcSAAAXEgFnn9JSAABSZUlEQVR4nO3deXxU9b3/8fdksi8QEsAgyI6UsmMBAanIzyK4UQSsbbWCWqtWi0V7bXtFqXSxFRGstuW6wK0itSJg1V4rGGvVACKLyCJgQDZBQgiQdbLM9/dHmMlMMgmZ5EzOzOT1fDzyYOasn0m+mvPO93y/x2GMMQIAAAAAWC7G7gIAAAAAIFoRuAAAAAAgRAhcAAAAABAiBC4AAAAACBECFwAAAACECIELAAAAAEKEwAUAAAAAIULgAgAAAIAQIXABAAAAQIgQuAAAAAAgRAhcAAAAABAiBC4AAAAACJFYuwtAXVlZWSouLlbXrl3tLgUAAABo9Q4ePKiUlBQdO3Ys6H3p4QpDxcXFqqiosLsMFRcXq7i42O4yEOFoR7AC7QhWoB3BCrSj1qmioqLJP3d6uMKQp2drx44dttaRnZ0tSRo/frytdSCy0Y5gBdoRrEA7ghVoR61T//79m7xvq+3hKikp0erVq3Xrrbeqb9++SkxMVEpKigYPHqxHHnlERUVFdfaZO3euHA5HvV8///nPbfgkAAAAAMJVq+3heumll/TDH/5QktSvXz9de+21OnPmjHJycvTwww9r+fLleu+999SxY8c6+44ZM0a9e/eus/yiiy4Ked0AAAAAIkerDVxxcXG6/fbbde+996pfv37e5UePHtVVV12lLVu26N5779VLL71UZ9/bbrtNM2bMaMFqAQAAAESiVntL4c0336zFixf7hS1J6tSpk55++mlJ0sqVK1VeXm5HeQAAAACiQKsNXA0ZPHiwJMnlcik/P9/magAAAABEqlZ7S2FD9u3bJ6n6tsOMjIw667Ozs7V161aVlZWpS5cumjRpEuO3AAAAANRB4Apg0aJFkqSJEycqISGhzvoXXnjB7/2cOXM0depULV26VKmpqY0+T33TS+bm5iorK8s77ahdPM8asLsORDbaEaxAO4IVaEewAu2odSouLlZKSkqT9uWWwlr++c9/6rnnnlNcXJzmzZvnt653796aP3++duzYoaKiIh06dEjLli1T586d9eqrr+qmm26yqWoAAAAA4chhjDF2FxEuPvvsM40ePVoFBQVauHChZs2a1aj9jh49qoEDByo/P1/r1q3TxRdf3Kw6PD1fPPgY0YB2BCvQjmAF2hGsQDtqnZpzfU4P11lHjhzRxIkTVVBQoNmzZzc6bEnVMxvOnDlTkvTWW2+FqkQAAAAAEYbAJenkyZOaMGGCDhw4oJkzZ2r+/PlBH6NPnz6Sqnu7AAAAAEAicKmoqEiTJk3Szp07dd111+mZZ56Rw+EI+jgFBQWS1OTBdAAAAACiT6sOXC6XS5MnT9ZHH32kK664QsuXL5fT6Qz6OMYYrVq1SpI0bNgwq8sEAAAAEKFabeCqqqrSd7/7XWVnZ2vs2LFauXKl4uPj690+Ly9PTz/9tAoLC/2WFxUV6c4779SGDRuUlZWl6667LtSlAwAAAIgQrfY5XE899ZS3V6p9+/a66667Am43f/58tW/fXsXFxbr77rv185//XMOHD1enTp2Ul5enzZs3Kz8/X+np6VqxYoWSk5Nb8mMAAAAACGOtNnB5xlxJ8gavQObOnav27dsrMzNTDzzwgNavX689e/YoJydHTqdTPXr00IwZM/TTn/5UnTt3bonSAQAAA
ESIVhu45s6dq7lz5zZ6+7S0ND366KOhKwgAAABA1Gm1Y7gAAAAAINQIXAAAAAAQIgQuAAAAAAgRAhcAAAAAhEirnTQDAAAAQPja+1Whdh49I0nqnJ6kb3TPsLmipiFwAQAAAAg773x2XI/+32eSpKsGdorYwMUthQAAAADCjtuYmjcO++poLgIXAAAAgLDjm7diHJGbuAhcAAAAAMKO8UlcMZGbtwhcAAAAAMKPOzruKCRwAQAAAAg/3FIIAAAAACHiO2mGg8AFAAAAANYxfoHLxkKaicAFAAAAIOy4/W4ptK+O5iJwAQAAAAg7Rr6zFEZu4iJwAQAAAAg7frMUErgAAAAAwDpuxnABAAAAQIgwhgsAAAAAQsO3h4sxXAAAAABgIb8xXPaV0WwELgAAAABhxzBpBgAAAACEBrcUAgAAAECIGGYpBAAAAIDQ8LmjkFkKAQAAAMBK3FIIAAAAACHiO0thJE9TSOACAAAAEHaM34OPIzdxEbgAAAAAhB3jd0uhjYU0E4ELAAAAQNjxHcPliOB7CglcAAAAAMKO/y2F9tXRXAQuAAAAAGHHd9IMB2O4AAAAAMA6PPgYAAAAAELE/8HHkZu4CFwAAAAAwo6bWQoBAAAAIDQYwwUAAAAAIcIYLgAAAAAIEf9p4SM3cRG4AAAAAIQd/wcfRy4CFwAAAICwQw8XAAAAAISImzFcAAAAABAazFIIAAAAACHDc7gAAAAAICTcjOECAAAAgNBgDBcAAAAAhIhhDBcAAAAAhIZvDxdjuAAAAADAQn49XBH86GMCFwAAAICwQw8XAAAAAISIYZZCAAAAAAgNt/89hRGLwAUAAAAg7NDDBQAAAAAhYsQYLgAAAAAICTc9XAAAAAAQGr5juCI4bxG4AAAAAIQfvzkzIjhxEbgAAAAAhB3Dc7gAAAAAIDTcfrPCR27iInABAAAACDvMUggAAAAAIeJ217xmDBcAAAAAWIhZCiNcSUmJVq9erVtvvVV9+/ZVYmKiUlJSNHjwYD3yyCMqKiqqd9+lS5dqxIgRSk1NVUZGhq688krl5OS0YPUAAABA68FzuCLQSy+9pClTpuj555+X0+nUtddeq7Fjx2r//v16+OGHNXz4cB0/frzOfvfee69mzpyp7du36/LLL9eIESO0Zs0affOb39Tq1atb/oMAAAAAUcjNLIWRLS4uTrfffrt27typnTt36u9//7veeust7d69W0OHDtVnn32me++912+ftWvXatGiRcrMzNQnn3yi1atX66233tJ//vMfOZ1OzZw5U6dOnbLl8wAAAADRxG+WQgJX5Ln55pu1ePFi9evXz295p06d9PTTT0uSVq5cqfLycu+6BQsWSJIefPBB9enTx7t81KhRuuOOO3Tq1Ck999xzLVA9AAAAEN2M3xiuyE1crTZwNWTw4MGSJJfLpfz8fElSaWmpsrOzJUnTpk2rs49n2euvv95CVQIAAADRyydvMYYr2uzbt09S9W2HGRkZkqTdu3fL5XKpQ4cO6tKlS519hg0bJknatm1byxUKAAAARCm/WQptrKO5Yu0uIBwtWrRIkjRx4kQlJCRIkg4ePChJAcOWJKWkpCg9PV0FBQUqLCxUWlraOc/Tv3//gMtzc3OVlZXl7VGzS3FxsSTZXgciG+0IVqAdwQq0I1iBdtRySkrKvK8/+WSrKg87bauluLhYKSkpTdqXHq5a/vnPf+q5555TXFyc5s2b513umSY+OTm53n09P4TCwsLQFgkAAABEOZ/nHkd0aKGHy8dnn32mG2+8UcYYPfbYY96xXKGyY8eOgMs9PV/jx48P6fnPxfOXG7vrQGSjHcEKtCNYgXYEK9COWk7CR9lSaakkaeiwoRrdq71ttTS1d0uK7LBoqSNHjmjixIkqKCjQ7NmzNWvWLL/1qampkqofmFwfTxdzY24nBAAAANA4TJoR4U6ePKkJEybowIEDmjlzpubPn19nm65du0qSDh8+HPAYxcXFOnXqlNq1a0fgAgAAAJrJ/8HHBK6IVVRUpEmTJmnnzp267rrr9MwzzwSc579v375KSEhQXl6ejhw5Umf95s2bJUmDBg0Kec0AAABAtPObpTBy81brDlwul0uTJ0/WRx99pCuuuELLly+X0xl49pOkpCTvvbqvvPJKnfUrVqyQJF1zzTWhKxgAAABoJfyfw2VfHc3VagNXVVWVvvvd7yo7O1tjx47VypUrFR8f3+A+s2fPliT9+te/1t69e73L161bp8WLFys9PV233nprSOsGAAAAWgO3T+AKdAdapGi1sxQ+9dRTWrVqlSSpffv2uuuuuwJuN3/+fLVvXz0jyuWXX65Zs2Zp0aJFGjJkiL71rW+pvLxca9askTFGS5YsUXp6ekt9BAAAACBqGR58HNkKCgq8rz3BK5C5c+d6A5ckLVy4UEOGDNFTTz2lNWvWKD4+XpdffrnmzJmj0aNHh7RmAAAAoLXw6eCK6EkzWm3gmjt3rubOndukfWfMmKEZM2ZYWg8AAACAGsxSCAAAAAAh4nYzSyEAAAAAhITxmzTDvjqai8AFAAAAIOxEyxguAhcAAACAsMODjwEAAAAgRJg0AwAAAABCxHcMV0zk5i0CFwAAAIDw4xu4IvnRxwQuAAAAAGHH/5ZCGwtpJgIXAAAAgLDDLIUAAAAAECJMmgEAAAAAIcKDjwEAAAAgBIz/jBkELgAAAACwits/b3FLIQAAAABYxU0PFwAAAACEhqGHCwAAAABCgx4uAAAAAAiR2j1cDkVu4iJwAQAAAAgrRv6JKyZy8xaBCwAAAEB4YZZCAAAAAAgRxnABAAAAQIjUGcMVwYmLwAUAAAAgrBjDGC4AAAAACInaY7jo4QIAAAAAi9DDBQAAAAAhwiyFAAAAABAitXu4IhmBCwAAAEBYqR236OECAAAAAIvUfg4XY7gAAAAAwCLMUggAAAAAIcIshQAAAAAQIrXnzKCHCwAAAAAs4juGK4KzliQCFwAAAIAw4zuGK5JnKJQIXAAAAADCjO8YrkgevyURuAAAAACEGd8eLociO3ERuAAAAACEFcMYLgAAAAAIDd9JChnDBQAAAAAWYpZCAAAAAAgRt7vmNT1cAAAAAGAhI3q4AAAAACAkjN8shZGNwAUAAAAgrPiO4YqJ8AdxEbgAAAAAhBXfHi7GcAEAAACAhfx6uCI7bxG4AAAAAIQXt++DuCJ8FBeBCwAAAECYoYcLAAAAAELCzRguAAAAAAgNt5vncAEAAABASPgO4aKHCwAAAAAs5DtLYYTnLQIXAAAAgPDi+xwuAhcAAAAAWIgHHwMAAABAiPg/+JjABQAAAACW8RvDZWMdViBwAQAAAAgrvrMURngHF4ELAAAAQHgx3FIIAAAAAKHhdte8jvC8ReACAAAAEF548DEAAAAAhIj/g48JXBFr06ZNevTRR3XdddepS5cucjgcDf5A586d690m0NfPf/7zFqweAAAAiE4mimYpjLW7ADvNmzdPr732WtD7jRkzRr17966z/KKLLrKiLAAAAKBV83vwcYR3EbXqwDVq1CgNGjRIw4cP1/Dh
w9W9e3e5XK5z7nfbbbdpxowZoS8QAAAAaIXcvoErwm8pbNWB64EHHrC7BAAAAAC18OBjAAAAAAgR/wcfR3bkatU9XE2VnZ2trVu3qqysTF26dNGkSZMYvwUAAABYxP/BxzYWYgGH8f00rVxiYqJcLpfq+5bMnTtXv/rVrwKumzp1qpYuXarU1NRGn69///4Bl+fm5iorK0vPP/98o48VCsXFxZKklJQUW+tAZKMdwQq0I1iBdgQr0I5axrovK/XM9gpJUu/0GP1yRIKt9dxyyy1KSUnRjh07gt6XWwqD0Lt3b82fP187duxQUVGRDh06pGXLlqlz58569dVXddNNN9ldIgAAABDx3D6vI7yDix4uX+fq4arP0aNHNXDgQOXn52vdunW6+OKLm1WHp+erKQnaStnZ2ZKk8ePH21oHIhvtCFagHcEKtCNYgXbUMlZsOqz7X/lEkjSiR4b+/qNRttbTnOtzergs0KlTJ82cOVOS9NZbb9lcDQAAABDZmKUQdfTp00dSdW8XAAAAgKbznzQjsiMXgcsiBQUFkhhACQAAADSX7wifmAhPLBFefngwxmjVqlWSpGHDhtlcDQAAABDZ3D6ByxHhNxUSuBopLy9PTz/9tAoLC/2WFxUV6c4779SGDRuUlZWl6667zqYKAQAAgOjgN4YrsvNW637w8Ztvvql58+Z535eXl0uS3yyDc+bM0VVXXaXi4mLdfffd+vnPf67hw4erU6dOysvL0+bNm5Wfn6/09HStWLFCycnJLf45AAAAgGjiO2d4pI/hatWBKy8vTxs2bKiz3HdZXl6eJCkzM1MPPPCA1q9frz179ignJ0dOp1M9evTQjBkz9NOf/lSdO3dusdoBAACAaGXo4YoOM2bM0IwZMxq1bVpamh599NHQFgQAAABAbjezFAIAAABASPjfUmhbGZYgcAEAAAAIK36zFNLDBQAAAADW8RvDZWMdViBwAQAAAAgrfg8+pocLAAAAAKzj+xyumAhPLBFePgAAAIBo4zeGK8JvKiRwAQAAAAgrRtHzHC4CFwAAAICwwhguAAAAAAgR3wcfR3jeInABAAAACC/+Dz6O7MRF4AIAAAAQVnxnKYzwvEXgAgAAABBemKUQAAAAAELF9zlckZ23CFwAAAAAwoubWQoBAAAAIDQYwwUAAAAAIeI7S6EjwhMXgQsAAABAWHEzhgsAAAAAQsP4zlJI4AIAAAAA6xi/Hq7ITlwELgAAAABhhVkKAQAAACBEfMdwRToCFwAAAICwYujhAgAAAIDQMMxSCAAAAACh4WaWQgAAAAAIDTezFAIAAABAaPhOmeEgcAEAAACAdXzHcEV43iJwAQAAAAgvbnfNaybNAAAAAAALGTGGCwAAAABCwm+WQvvKsASBCwAAAEBYcfuN4YrsyEXgAgAAABBefHq4uKUQAAAAACzkZpZCAAAAAAgNt18Pl311WIHABQAAACCs8OBjAAAAAAgRbikEAAAAgBAxhudwAQAAAEBIGMZwAQAAAEBouOnhAgAAAIDQ8J2lMNIRuAAAAACEFcODjwEAAAAgNPwnzbCxEAsQuAAAAACEFf9p4SM7cRG4AAAAAIQV3yFc9HABAAAAgIV8J82ghwsAAAAALGT8bim0sRALELgAAAAAhBVmKQQAAACAEHEzSyEAAAAAhIbfLIWK7MRF4AIAAAAQVozfpBn21WEFAhcAAACAsMIYLgAAAAAIETezFAIAAABAaPg/+DiyExeBCwAAAEBYoYcLAAAAAELE7TdpRmQnLgIXAAAAgLBieA4XAAAAAIQGsxQCAAAAQIj4P/g4shG4AAAAAIQVxnBFiU2bNunRRx/Vddddpy5dusjhcDTqB7p06VKNGDFCqampysjI0JVXXqmcnJwWqBgAAACIftE0hivW7gLsNG/ePL322mtB7XPvvfdq0aJFSkpK0oQJE1RWVqY1a9bo7bff1ooVK/Ttb387NMUCAAAArYSJoh6uVh24Ro0apUGDBmn48OEaPny4unfvLpfLVe/2a9eu1aJFi5SZmal169apT58+kqR169Zp3LhxmjlzpsaNG6f09PQW+gQAAABA9HHTwxUdHnjggaC2X7BggSTpwQcf9IYtqTq43XHHHXryySf13HPP6b777rO0TgAAAKA18engYpbC1qK0tFTZ2dmSpGnTptVZ71n2+uuvt2hdAAAAQLRx+91TaF8dViBwNdLu3bvlcrnUoUMHdenSpc76YcOGSZK2bdvW0qUBAAAAUSWansPVqm8pDMbBgwclKWDYkqSUlBSlp6eroKBAhYWFSktLO+cx+/fvH3B5bm6usrKyvD1qdikuLpYk2+tAZKMdwQq0I1iBdgQr0I5aRnFJmff1tk+2quqw08Zqqn/uKSkpTdqXHq5GKioqkiQlJyfXu43nh1BYWNgiNQEAAADRyLeHK9LRw2WjHTt2BFzu6fkaP358S5ZTh+cvN3bXgchGO4IVaEewAu0IVqAdtYyEjdlSaakk6aKhQzW6d3tb62lq75ZED1ejpaamSpJKSkrq3cbTxdyY2wkBAAAABBZNz+EicDVS165dJUmHDx8OuL64uFinTp1Su3btCFwAAABAMxifxBXheYvA1Vh9+/ZVQkKC8vLydOTIkTrrN2/eLEkaNGhQS5cGAAAARBV3FM1SSOBqpKSkJO+9uq+88kqd9StWrJAkXXPNNS1aFwAAABBtjM+jj2MiO28RuIIxe/ZsSdKvf/1r7d2717t83bp1Wrx4sdLT03XrrbfaVR4AAAAQFdxRNIarVc9S+Oabb2revHne9+Xl5ZKkiy++2Ltszpw5uuqqqyRJl19+uWbNmqVFixZpyJAh+ta3vqXy8nKtWbNGxhgtWbJE6enpLfoZAAAAgGgTTWO4WnXgysvL04YNG+os912Wl5fnt27hwoUaMmSInnrqKa1Zs0bx8fG6/PLLNWfOHI0ePTrkNQMAAADRzkTRGK4WC1y7du3Sjh07dMEFF2jkyJEtddoGzZgxQzNmzGix/QAAAACcm9swhiugl19+WePHj6/Ta/Szn/1MAwYM0He+8x2NHj1aU6ZMUVVVlZWnBgAAABAlKt2+gSuyE5elgevFF1/U1q1bNXToUO+ynJwcPf7440pLS9MNN9yg7t276x//+IeWLVtm5akBAAAARIEqt1GRq9L7vk1inI3VNJ+lgWv79u0aNGiQ4uPjvcteeOEFORwO/f3vf9eyZcu0ceNGpaam6tlnn7Xy1AAAAACiQGFZhd8YrrZJBC6v48ePq3Pnzn7L3n33XXXs2FETJkyQJGVkZOib3/ymPv/8cytPDQAAACAKnC6t8L52OKS0xMie58/SwJWUlKQzZ8543x89elR79uzRpZde6rddenq6CgoKrDw1AAAAgChwqqQmcLVJjFNMhM+aYWng6tmzp95//32dOnVKkrRs2TI5HA5v75bHsWPH1LFjRytPDQAAACAK+PZwpSdH9u2EksWBa8aMGTpz5owuuugiTZ06VQ8++KBSU1M1efJk7zYVFRX6+OOPdeGFF1p5agAAAABR4JRP4Ir08VuSxc/
h+uEPf6h3331Xr776qvbv36+UlBQtXrxYmZmZ3m3eeOMNnT59WuPHj7fy1AAAAACiwGkCV/3i4uL0yiuv6IsvvlBeXp6+9rWvKS0tzW+bHj16aNWqVbr44outPDUAAACAKHC6pNz7Oj05voEtI0NIpvzo3r27unfvHnDdkCFDNGTIkFCcFgAAAECE++qMy/u6bVJkz1AohShwSdLJkye1adMmnThxQt26ddPo0aNDdSoAAAAAUWDTgZN6Yf0B7/v0pMjv4bJ00gxJysvL0/e+9z1lZWVp4sSJuvHGG/0ecvzss88qIyNDH3zwgdWnBgAAABChyiqqNHPJRr9l3dun2FSNdSwNXCdPntTo0aP1t7/9TQMGDNBdd90l4/uYaEnXXXedCgsLtWLFCitPDQAAACCCfXWmTGfKKr3vZ4zurqsHdbKxImtYGrh+85vfKDc3Vw899JA2b96sP/7xj3W2ycjI0KBBg/Tee+9ZeWoAAAAAEay80u19HeOQ5l7bX4lxThsrsoalgWv16tW68MILNXfu3Aa369Wrl44cOWLlqQEAAABEMJdP4EpJiPzJMjwsDVxHjhzR4MGDz7mdw+HQmTNnrDw1AAAAgAhWXlUTuBJiLZ9qwjaWfpI2bdro6NGj59wuNzdXHTp0sPLUAAAAACKYq6ImcMU7CVwBDR8+XBs3btT+/fvr3eaTTz7R1q1bNWbMGCtPDQAAACCC+fVwRcHYLQ9LA9c999wjl8ulKVOmaNeuXXXWf/7557rppptkjNHdd99t5akBAAAARDDfSTPo4arHxIkT9V//9V/atm2bBgwYoK997WtyOBz617/+pcGDB6tfv37avn27fvnLX+qSSy6x8tQAAAAAIpirssr7Op4xXPV79NFH9fLLL2vgwIHas2ePjDE6evSoPv30U/Xp00fLli3TvHnzrD4tAAAAgAjm28MVTZNmWDrf4pkzZ+RwODR9+nRNnz5deXl5+uKLL+R2u9WlSxd17tzZytMBAAAAiBJ+txQSuAJLT0/XyJEjtW7dOklShw4dmI0QAAAAwDm5ojRwWfpJ2rZtq549e1p5SAAAAACtQLTeUmjpJxk6dKhyc3OtPCQAAACAVsB3Wvj4WKaFD+iBBx7Qxo0btWLFCisPCwAAACDKuSp8ZimMomnhLR3DlZSUpNtuu03f+c53dPXVV+uaa65R165dlZiYGHD7b37zm1aeHgAAAECEcvk9+JjAFdC4cePkcDhkjNHrr7+uN954o8Htq6qqGlwPAAAAoHWI1gcfWxq4fvCDH8jhcFh5SAAAAACtQLROmmFp4Fq6dKmVhwMAAADQSriiNHBFzycBAAAAELF48HGQysvLtXXrVh05ckSS1LlzZw0ZMkTx8fGhOiUAAACACEXgaqSysjI99NBDWrx4sYqKivzWpaam6o477tCvfvWremcuBAAAAND6uCprJtRLiKLncFkauFwuly6//HKtW7dOkjRo0CB1795dDodDX3zxhT755BPNnz9fH374od555x0lJCRYeXoAAAAAEcr/wcfR08Nl6Sd54oknlJOTozFjxmjr1q3asmWLVq1apZUrV2rz5s365JNPNHbsWK1bt04LFy608tQAAAAAIli0Tgtv6SdZvny5OnTooDfffFMDBw6ss37AgAF644031L59ey1btszKUwMAAACIYK4oHcNl6Sf5/PPPNW7cOKWlpdW7TWpqqsaNG6fc3FwrTw0AAAAggkXrc7gs/SSxsbEqKSk553YlJSWKjQ3ZBIkAAAAAIkxphc+kGXHRM2mGpYFr4MCBys7O1r59++rdZv/+/crOztagQYOsPDUAAACACGWM0fEzLu/79qnR8ygpSwPXj370I5WWlmrcuHF67rnnVFpa6l1XWlqqJUuWaNy4cSorK9Mdd9xh5akBAAAARKhCV6VfD9d5baLnEVKW3td300036YMPPtAzzzyj22+/Xbfffrvat28vSTpx4oSk6vT6ox/9SN///vetPDUAAACACHX8TJn3dZzToYxkerjqtXjxYr3yyiu65JJLFBcXp7y8POXl5SkuLk5jx47VK6+8oj//+c9WnxYAAABAhDp2uuZ2wo5piYqJcdhYjbVCMnPF1KlTNXXqVFVWVio/P1+SlJmZyUQZAAAAAOr4yqeHq2ObBBsrsV5IE1BsbKzOO++8UJ4CAAAAQIT7qrAmcJ2XFj3jt6QQPIfrySef1Pbt2+vdZvv27XryyScbnMkQAAAAQOvhO0PheVHWw2Vp4Fq4cKHuu+8+tWnTpt5t0tLSNHv2bD355JNWnhoAAABAhCotr5mhMDUxuoYhWRq43nnnHQ0ZMkRdu3atd5tu3bppyJAhWrNmjZWnBgAAABChXJU+Dz2OjZ6HHksWB65Dhw6pZ8+e59yuV69eOnz4sJWnBgAAABChXJVu7+uEWMsnUreVpZ/G6XTK5XKdczuXy6WqqqpzbgcAAAAg+pX5PPQ4MY4ernpdeOGF+uCDD1RSUlLvNiUlJfrggw/Up08fK08NAAAAIELRw9VI06ZN08mTJ3XbbbepuLi4zvqSkhL98Ic/VEFBgaZNm2blqQEAAABEKL/AFRddgcvSKUDuuecevfDCC3r55Zf17rvv6rvf/a569eolScrNzdXy5ct1/Phx9e3bV/fee6+VpwYAAAAQoaJ50gxLA1dycrLWrl2rG2+8UdnZ2Vq4cKEcDockyRgjSbrsssv0wgsvKCUlxcpTAwAAAIhQZRU1PVyJ9HA1LCsrS2vXrtXGjRu1du1aHTp0SJJ0wQUX6PLLL9fw4cOtPiUAAACACEYPVxMMHz7cG67Kysp06tQptW/fPlSnAwAAABChXBVMmhFQYWGhPvroI+3Zsyfg+r179+qKK65QmzZt1LlzZ6WlpemGG27Q0aNHm3NaAAAAAFGEaeHr8de//lWjRo3SG2+8UWfdsWPHNHbsWK1du1aVlZUyxsjlcumVV17RZZddptLS0uacGgAAAECUYFr4erz33nuKiYnRjTfeWGfdvHnzdPz4cbVr106vvfaaCgsLtXnzZn3jG9/Q3r179ec//7k5pwYAAAAQBYwxtQIXPVxe27Zt08CBA9WxY0e/5W63W8uXL5fD4dBvf/tbXXPNNUpJSdGQIUO0cuVKxcbGavXq1c05NQAAAIAo4Bu2pOibpbBZn8bzTK3atm3bplOnTik2NlY33HCD37rOnTtr5MiR2rVrV3NODQAAACAK1A5c9HD5KC4uVkVFRZ3lmzZtkiQNGjRIbdq0qbO+S5cuOnPmTHNObatx48bJ4XDU+/XWW2/ZXSIAAAAQEXynhJekhCjr4WrWtPAdOnTQZ599Vmf5Bx98IIfDoREjRgTcr6ysTG3btm3OqcPC1KlTlZqaWmd5586dbagGAAAAiDy+U8JL0TdpRrMC18iRI7V69Wq9/vrruuaaayRJeXl5WrlypSRpwoQJAffbuXOnzj///OacOizMnz9f3bt3t7sMAAAAIGL59nDFx8bI4XDYWI31mhUf7777bhljNH36dP3gBz/Qfffdp+HDh6uwsFAXXHCBrrrqqj
r77Nu3T3v27NHgwYObc2oAAAAAUaAsih96LDWzh+uyyy7T3Llz9atf/UovvviiHA6HjDFKSkrSkiVLFBtb9/B/+ctfJElXXHFFc04NAAAAIAr4TpoRbQ89lpoZuCTpoYce0tVXX62VK1cqLy9PF1xwgb7//e+rR48eAbdPSEjQrFmzNHHixOae2nbPPfec8vPzFRMTowsvvFDf/va31bVrV7vLAgAAACKGq6LmlsJo7OFyGGOM3UVEmnHjxum9996rszwuLk5z5szRnDlzGnWc/v37B1yem5urrKwsPf/8882qs7mKi4slSSkpKbbWgchGO4IVaEewAu0IVqAdWW9bXpUWbimXJHVKceg3YxJtrqiuW265RSkpKdqxY0fQ+0ZfhGwB3/zmN/XCCy8oNzdXJSUl2r17t37zm98oNjZWDz30kBYtWmR3iQAAAEBE8J2kMMpmhJdED5el3n77bV1xxRVKT0/Xl19+qaSkpCYdx9Pz1ZQEbaXs7GxJ0vjx422tA5GNdgQr0I5gBdoRrEA7st7qLUd078tbJUnDuqZr5V1j7C0ogOZcn0dhhrTPhAkT9I1vfEOnTp3Shg0b7C4HAAAACHu+08InxEbfpBkELov16dNHknT06FGbKwEAAADCn+8shQlReE9h9H0imxUUFEhiICUAAADQGGU+sxQm0sOFhuTl5en999+XJA0bNszmagAAAIDw56qghws+cnJytHr1alVVVfkt/+KLLzRlyhQVFxfr2muvVZcuXWyqEAAAAIgcfrcURuFzuJr94OPWZs+ePZo5c6aysrI0bNgwpaen68CBA9q0aZPKysrUv39/PfPMM3aXCQAAAEQEv1sK46LvlkICV5BGjhypO++8Uxs2bNDGjRtVUFCglJQUDRkyRNOnT9edd97Z5OngAQAAgNaGHi746devn/70pz/ZXQYAAAAQFZgWHgAAAABCxLeHK5FJMwAAAADAOr5juOjhAgAAAAAL8eBjAAAAAAgR3+dw8eBjAAAAALBQme+kGfRwAQAAAIB1fHu4onFa+Oj7RAAAAAAiht+08FH44GMCFwAAAADblNHDBQAAAACh4TdLIZNmAAAAAIB1/G4ppIcLAAAAAKzjNy08Y7gAAAAAwBput1F5FWO4AAAAAMByvs/gkngOFwAAAABY5ujpMu9rZ4xDGcnxNlYTGgQuAAAAALY4XFDqfZ3VJlGxzuiLJ9H3iQAAAABEhCM+gatLuyQbKwkdAhcAAAAAWxwuKPG+7kzgAgAAAADrHDnl28OVbGMloUPgAgAAAGAL3zFcXdLp4QIAAAAAyzCGCwAAAABCoLzSra8Ka6aFZwwXAAAAAFjk6OlSGVP92uGQOrUlcAEAAACAJXzHb52Xlqj42OiMJtH5qQAAAACEtdYwfksicAEAAACwwWGfKeGjdfyWROACAAAAYIOvTtdMmJHVNtHGSkKLwAUAAACgxfnOUHheGoELAAAAACxzjB4uAAAAAAiN44Uu7+vz2iTYWEloEbgAAAAAtChXZZVOFpd735/Xhh4uAAAAALDE5gOn/N53ZAwXAAAAADRfldto1t+2eN+3T42P2oceSwQuAAAAAC3o37uP+43fmnbRBTZWE3oELgAAAAAt4kxZhR5cvd37/qJu7fTAxL42VhR6BC4AAAAALWL1liM66jMd/H0TLpTD4bCxotAjcAEAAABoEZ8dK/S+vrzfeRrdq72N1bQMAhcAAACAFvH58SLv60svjP6wJRG4AAAAALSQXJ/A1atjqo2VtBwCFwAAAICQKyguV77Pw457E7gAAAAAwBqf59X0brVJjFWH1AQbq2k5BC4AAAAAIec7fqt3x9Son53Qg8AFAAAAIORqB67WgsAFAAAAIORy8whcAAAAAGA5Y4x2fnnG+57ABQAAAAAW2Xn0jI4XuiRJDoc0uEu6vQW1IAIXAAAAgJBau/O49/XQC9KV2UpmKJQIXAAAAABCyO02enXzYe/7/9fvPBuraXkELgAAAAAhs25fvg6eLJEkxTikKUM721xRyyJwAQAAAAiJY6fL9P1nN3jfX3phB52fnmRjRS2PwAUAAADAcm630e0vfOy37IYRXW2qxj6xdhcAAAAAIHpUVrn17915+uO7n2vb4dPe5d0zkzX+ax1trMweBC4AAAAAzfbZsTNatfmI/rn9qA6dLPVblxzv1Ct3jFacs/XdYEfgAgAAANAsr209ontf3ipj6q4b1TNTT39/mDJS4lu+sDBA4AIAAADQLM++v79O2Bp8QbpuHtVNU4Z2lsPhsKewMEDgAgAAANAk+08U65HXd+jTIzVjtWaM7q6bRnVTrw6pNlYWPghcAAAAAILmdhv96IWPteerIu+yvuelae61/W2sKvwQuAAAAAA0Wkl5pdbvy9eitXv9wlac06FHJhO2aiNwAQAAAGiQMUafHjmttbuOa+mH+3WmrNJvfYxDev+/xiurbaJNFYYvAhcAAAAAP4cLSvTVGZfOlFboVGm53tp+TP/a8VXAbQd2bqvnZnxDHdMIW4EQuJqgtLRUv/vd7/S3v/1NBw8eVEZGhiZOnKh58+apc+fOdpcHAAAAnFNFlVsnilw6XVqh/KJy7TtRrJzPT2jjFyd1oqj8nPtf3u883Ty6my7p3b5Vz0J4LgSuIJWVlWn8+PFav369OnXqpMmTJ+uLL77QkiVL9MYbb2j9+vXq2bOn3WUCAACgFahyGxWWVaiwrFKnSyt0pqxCJa4qVVS5VV7lVllFlU6XVuhkcYWOF5Ypr9ClvEKXjhe6dLL43KGqtm6ZyRrbp71uGdNDPZmFsFEIXEH69a9/rfXr12vUqFF6++23lZpa3dAWLFig++67T7fccov+/e9/21skAAAAwp4xRq5Kt86UVqigpEIni8t1qqRcJ0vKVeKqUmlFlcoqqlTkqtSZ0upQdeZsuDpTWqEzZZUqclWe+0RNlJESr/TkOLVNilNmSoImDsjSlKGd5YyhNysYBK4glJeX66mnnpIkPf30096wJUmzZ8/W//7v/+q9997Tpk2bdNFFF9lVJgAAAIJUWeVWWWV1j5Dr7L/VX265KqvkqqhetvnLSlW4pSPrD6iyyq2KKreKXFUqLa9URZVRpdutyirj97q8yq3KKrcq3UZlFVXKLyr39kZVVJlzFxdiSXFOtU2KU9fMZPVsn6JRvTI1rGs7XZCRbHdpUYHAFYQPP/xQp0+fVq9evTR06NA666dNm6Zt27bp9ddfJ3ABAICoYoxRlduoyhi53VLV2fdu7zKjSrdR5dmgUXX2fZXbyBif7WvtV+U+G07O3gLnee0+u4/bXWs/I+9rz/Lq99XLK6o8gad2eHKrzBOcfAKUZ5tKd5DBZ+f20HyjmyEtMVZtEuOUkuBUnDNGsc4YJcbGKD05TulJ8eqQlqCObRLUMS2h+nVaojqkJSgxzml36VGNwBWETz75RJI0bNiwgOs9y7dt29ZiNQFAsIypvvgxnteSzNnrDCNT89rUvPfd1rNO9az37CfvMes5ls95fM9dty6jI0VuSdKerwr9j2X8a/Y9nueY/ueo+
R7UbFN7v8D7+C07+6LuvjXnVJ1zNr4uv8s+U88+ps4mDdTl83nrqcs0uE3Nuf2/9z7tos7P2P+95wDedX7tpZ5z1HrvqaXec9Rqg7W3O3CgQpK0vmxXvcfxnqOez+Bpk7XbX+22X/e/iXq+VwGOIb/31fu4z37P3GfDjpHxWaazoaP6AL7vPfuYs8uMzzrPe//QUhNqqnyCVJXPf7MIjfjYGGWmxKtdcrzapcQpLSFOiXExSoxzKiUhVmmJsUpLrL69zxOs2iR5/o1TakIst/qFKQJXEA4ePChJ6tKlS8D1nuUHDhxo1PH69w/8YLjc3FxlZWUpOzu7CVVap7i4WJJsr6O18fwirDI6+wux+rXxLqu7vmY7U2cft5HcqvmF7/b5Be+us+zsPqrZ17vM55e+22dd4Pc1FxnlFZUyRlq6463qz+epxXNsqc7Fim8Nkn8dkv+Fj+/vf9/1NRc1gbapHTL8t/duW/uctY5Xp57a29R7PBNwm3MeL4jPXLt+3+NGtJz/2F0BosEX++yuABHA6ZDinVJcjBQf41Dc2ddOuauXxTrldEjOGCnB6VCis/q10+FQbIy865wOz2uHnA4pNkZKi3coLc6hpFgpKdah5Njqc1XP9Fdx9iuA8rNfp6VSVX8FnqgdVisuLlZKSkqT9iVwBaGoqPpJ2snJge9n9fwQCgsLW6wm1K/SbVReJbmqJFeV8f5bXiVVuqsDiSegVLlrAkyV58tt/N7XbGd8tvHfvtItVZ5dXmmqa/Ceyy2foGT8zlc7tEQXz1/bqmytAgDQcpyO6gfher8kOfyWORSj6teOsyEk9uy/TodDzpiz61Szj8N7HIffcb3rzr53+hwjPuZsYHI6FBcjxTnPhqcYT5hynF1W8z7+7HYx9Uxz7vmDdEpKQst8MxHxCFw22rFjR8Dlnp6v8ePHt2Q5dXh6tuyuo7zSrZPF5TpR5FJ+cbnyi1zKLyrXieLqf/PPLi8qq1RJeZVKyitVWlEVFoNQATSO57rGoeqLKcfZZQ5VX3EZt1sOSU6n07u977byee97PM8x5buu1jl9t/Ktw/99zcZ1t6l7Uebdptb+jalLtY7fUO2B6jrXPg3V5f8tqbt/fT8f3/cOh//r2rUEOo48+9VzHAU8d+2ff/3H8Bz/yyNH5JDUuUvnWts7fD5jrePUOof/98J/e9X63tQ+Tr3nqPW9qv0ZYxyefx3ecOH7Psbh8Fkmn+2qt3FIiompdQzVbOOMObtfjONsUKnezhnjCT61l/m8djgUEyPvstiz66P5mUzhcn2EltXU3i2JwBUUz6yEJSUlAdd7/uKRlpbWYjVFMs9zI6pDVHVwOlHk0omi6nC15eApHS4o0Zmy0E13GqliHPL+0vP8cox1xvj9svP9qvmF7PDuG+gXs+cXp9+2tfaLcTjO/lKu/uUc492v5hez58sZIx06dEgOSd27d/Orw6Gac8ecvee8oYuHc13QyGdZwIuzWhdLqr2+1gXQ2S0CX2T6nKd6q/qP5Xsh6Let37kCXPjVOpZqH9t7vIaP5ft56hwrQM2ec9e9gK11ke3Zpp71vuepW6fv97JxF2Vc4MAK2dl5kqTx4wfaXAmA1oTAFYSuXbtKkg4fPhxwvWd5t27dWqymSFBe6daOL09r4xcn9faOr3S4oFSFZRUqLm+5W8ycMQ4lxzmVFO9UcrxT8bExio2JUayzOpDExcScDS3VgSXWGeMNLnHOmLP/ng02MTHe5XFOh/c4CbHV+8TFxijOGaN4Z8zZGYIcinf6Hv/sa59AFBvjc+yzNXjDk7MmVNX8NTFy/nKYnX1MkjR+/NdsrgQAAKDlEbiCMHjwYEnS5s2bA673LB80aFCL1RRujDEqq3DrUEGJ/u/TY9p3okj/3p2n06X1DP4MQlpCrDJT45WZmqDMlOp/26fGKzMlXhmpCWqTGKvk+Fglx9cEq6SzISveGRPVtzcAAAAgPBG4gjBmzBi1bdtWubm52rp1q4YMGeK3fsWKFZKka665xobq7PXlqVI9/I8dem93nsqr3EHt64xx+AWo9mcDVZd2SRrStZ06pFW/5xkRAAAAiDQEriDEx8fr7rvv1m9+8xv9+Mc/1ttvv+0dQLdgwQJt27ZNl156aat76PEL6w9ozuqGH/4X74zRoC5tNaJHhkb3aq92KXFqk1jzHIlIukUOAAAAaCwCV5AefPBBrV27Vjk5OerTp4/Gjh2rAwcOaMOGDerQoYOef/55u0tsMUdOlep3/9ylN7YdDbi+Z4cU/b+vdVS3zBRdPaiT0pPjW7hCAAAAwF4EriAlJibq3Xff1e9+9zu99NJLWr16tTIyMjRjxgzNmzev3ociR5Oyiio99q/demnDQZVW+E98cfs3e+q2sT3UNilOCbHcAggAAIDWjcDVBElJSXrkkUf0yCOP2F2KLX795k69uP6g37L+57fRn74/TN0ym/6MAgAAACDaELgQlPn/2l0nbP1+6kBNu+gCORmHBQAAAPghcKHRqtxG//OffX7L/n3/OHVvT68WAAAAEEiM3QUgcuQXufymfP/tlIGELQAAAKABBC402vFCl/d1akKsvjeyq43VAAAAAOGPwIVGO15Y5n3dMS3BxkoAAACAyEDgQqMdP1PTw9WBwAUAAACcE4ELjfaVT+Dq2CbRxkoAAACAyEDgQqNxSyEAAAAQHAIXGu3gyRLvawIXAAAAcG4ELjTKwfwSffD5Ce/7r5/fxsZqAAAAgMhA4EKjLM35QsZUv+7ZPkVjerW3tyAAAAAgAhC4cE6l5VX6+8eHvO9njumumBiHjRUBAAAAkYHAhXP6/HiRilyVkqTEuBhdN6yLzRUBAAAAkYHAhXM6VVrufd0xLVEpCbE2VgMAAABEDgIXzul0aYX3ddukOBsrAQAAACILgQvndKqkJnClJxO4AAAAgMYicOGc6OECAAAAmobAhXMicAEAAABNQ+DCOZ0qqZk0g1sKAQAAgMYjcOGc6OECAAAAmobAhXPymzQjKd7GSgAAAIDIQuDCOfn1cHFLIQAAANBoBC6c08nimjFc3FIIAAAANB6BCw06U250vNDlfd81I9nGagAAAIDIQuBCg/afdntfd0hLUKe2iTZWAwAAAEQWAhca5Bu4BndJl8PhsLEaAAAAILIQuNCgvFLjff31Tmk2VgIAAABEHgIXGlRVk7eUEOe0rxAAAAAgAhG40CDjE7i4mxAAAAAIDoELDXL7JC4niQsAAAAICoELDfLp4FIMgQsAAAAICoELDXL7JK6YGAIXAAAAEAwCFxrkO4aLvAUAAAAEh8CFBrl9XjtJXAAAAEBQCFxokP8shQQuAAAAIBgELjTIdwwXsxQCAAAAwSFwoUH+sxTaVgYAAAAQkQhcaBCzFAIAAABNR+BCg/xnKSRwAQAAAMEgcKFB/rMU2lYGAAAAEJG4hEaDjE8XFz1cAAAAQHAIXGiQm1sKAQAAgCYjcKFBPPgYAAAAaDoCFxrkP2mGfXUAAAAAkYjAhQZxSyEAAADQdAQuNIhp4QEAAICmI3ChQYzhAgAAAJqOwIUG+d5S
SAcXAAAAEBwCFxrkk7fo4QIAAACCROBCg5g0AwAAAGg6AhcaxKQZAAAAQNMRuNAgN8/hAgAAAJqMwIUGGZ9RXIzhAgAAAIJD4EKD/GcpJHABAAAAwSBwoUG+Y7jo4QIAAACCQ+BCg3wffEzeAgAAAIJD4EKDmKUQAAAAaDoCFxrEc7gAAACApiNwoUE+eYsxXAAAAECQCFxoEM/hAgAAAJqOwBWEf//733I4HPV+XXzxxXaXaDm/MVwkLgAAACAosXYXEIl69eqlSy65JODyaOM/SyGBCwAAAAgGgasJLrnkEi1dutTuMlqE33O4CFwAAABAULilEPUyxvhNmkHeAgAAAIJD4EK9TK33zFIIAAAABIdbCptg7969+sUvfqH8/Hy1b99el1xyiSZOnKiYmOjKr+5aiYsxXAAAAEBwCFxNkJOTo5ycHL9lAwcO1Kuvvqo+ffo0+jj9+/cPuDw3N1dZWVnKzs5uVp3NVVRcIt9O0HU5H6ptAqELwSkuLpYk29szIhvtCFagHcEKtKPWqbi4WCkpKU3aN7q6ZEKsbdu2+tnPfqb169crPz9f+fn5euedd3TxxRfr008/1YQJE3T69Gm7y7RM7R4uohYAAAAQHIcxpvZQnag1ZcoU7dq1K6h9/vrXv2rEiBENblNVVaXLLrtM77//vn7729/qF7/4RXPK9PZ87dixo1nHaa43335HP84u877fMudbapcSb2NFiESevwCOHz/e5koQyWhHsALtCFagHbVOzbk+b1W3FO7fv1+7d+8Oap+SkpJzbuN0OvXAAw/o/fff17/+9a9mB65wUTuKM4YLAAAACE6rClxbt24N2bE9Y7eOHj0asnO0NHet91E2JwgAAAAQclxCW6SgoECSmjyYLhwxSyEAAADQPAQui7z66quSpGHDhtlciXVq31LIc7gAAACA4BC4grBw4UIdOnTIb5kxRosXL9YTTzwhh8OhO++806bqrFdnlkLyFgAAABCUVjWGq7kWLlyo+++/X8OGDVOPHj1UVlamTz/9VPv371dMTIyefPJJXXTRRXaXaZna01c6SVwAAABAUAhcQbjvvvv09ttva8eOHdq5c6cqKirUqVMn3XjjjfrJT36i4cOH212ipdy17ilkDBcAAAAQHAJXEO655x7dc889dpfRYupMC88YLgAAACAojOFCvXynhSdrAQAAAMEjcKFevj1czFAIAAAABI/AhXr5zlLoYPwWAAAAEDQCF+rlO4SLGQoBAACA4BG4UC83txQCAAAAzULgQr2M3y2F9tUBAAAARCoCF+rlO0shPVwAAABA8AhcqJdvDxcPPQYAAACCR+BCvdwELgAAAKBZCFyol/GZp5A7CgEAAIDgEbhQL2YpBAAAAJqHwIV6MYYLAAAAaB4CF+rlN4aLlgIAAAAEjcto1Mt3Wnh6uAAAAIDgEbhQL99bCp0ELgAAACBoBC7Uy/eWQvIWAAAAEDwCF+rle0shsxQCAAAAwSNwoV7MUggAAAA0D4EL9XITuAAAAIBmIXChXj55i2nhAQAAgCbgMhr1cvvcU8gshQAAAEDwCFyol/GbpZDABQAAAASLwIV6MUshAAAA0DwELtSryidxkbcAAACA4BG4UK+DhTWJK6ttko2VAAAAAJGJwIV67SmoCVwjemTYWAkAAAAQmQhcCKjYValDhTWzZowkcAEAAABBI3AhoIKScr/ncPVsn2JbLQAAAECkInAhoPLKmtsJY2McinXSVAAAAIBgcRWNgFw+gSshlmYCAAAANAVX0gjIL3DFOW2sBAAAAIhcBC4E5HtLYTy3EwIAAABNwpU0AnJVVnlfJ8TRTAAAAICm4EoaAbkqGMMFAAAANBdX0gjIdwxXPIELAAAAaBKupBFQeZXPLYWxTJoBAAAANAWBCwFxSyEAAADQfFxJIyBuKQQAAACajytpBFTOg48BAACAZuNKGgH5TQvPGC4AAACgSQhcCIhbCgEAAIDm40oaAXFLIQAAANB8XEkjIJdf4OKWQgAAAKApCFwIyG8MVxzNBAAAAGgKrqQRkN8YLifNBAAAAGgKrqQRkN8thfRwAQAAAE3ClTQCclUwhgsAAABoLgIXAiqvYlp4AAAAoLm4kkZArgrfBx/TTAAAAICm4EoaAbl4DhcAAADQbFxJIyACFwAAANB8XEkjoHLf53AxaQYAAADQJLF2F4DwNKBzWzkqSlVRZZSeHGd3OQAAAEBEInAhoEU3DFV2drYkaWjXdjZXAwAAAEQmbikEAAAAgBAhcAEAAABAiBC4AAAAACBECFwAAAAAECIELgAAAAAIkVYbuIqLi/XCCy/onnvu0ciRI5WQkCCHw6G5c+eec9/Dhw9r5syZOv/885WYmKgLL7xQDz/8sMrKykJfOAAAAICI0Wqnhd+7d69+8IMfBL3f559/rlGjRunEiRMaMGCAxo4dq48//liPPPKI3nnnHb3zzjtKSEgIQcUAAAAAIk2r7eFKS0vTrbfeqr/85S/atGmTHnnkkUbtN2PGDJ04cUI/+clP9Omnn+rll1/W7t27NWXKFH344Yf63e9+F+LKAQAAAESKVhu4evXqpWeffVY/+tGPNGzYMMXFxZ1zn48++kgffvihOnbsqD/84Q/e5bGxsfrzn/+suLg4Pfnkk6qsrAxl6QAAAAAiRKsNXE3x5ptvSpKuueaaOrcNnnfeeRo7dqwKCgr0wQcf2FEeAAAAgDBD4ArCJ598IkkaNmxYwPWe5du2bWuxmgAAAACEr1Y7aUZTHDx4UJLUpUuXgOs9yw8cONCo4/Xv3z/g8tzcXGVlZSk7O7sJVVqnuLhYkmyvA5GNdgQr0I5gBdoRrEA7ap2Ki4uVkpLSpH3p4QpCUVGRJCk5OTnges8PobCwsMVqAgAAABC+IraHa8qUKdq1a1dQ+/z1r3/ViBEjQlRR8Hbs2BFwuafna/z48S1ZTh2ev9zYXQciG+0IVqAdwQq0I1iBdtQ6NbV3S4rgwLV//37t3r07qH1KSkqadc7U1NQGj+PpYk5LS2vWeQAAAABEh4gNXFu3bm3xc3bt2lVbtmzR4cOHA673LO/WrVtLlgUAAAAgTDGGKwiDBw+WJG3evDnges/yQYMGtVhNAAAAAMIXgSsIV111lSTp9ddfl8vl8lv31Vdf6f3331e7du00ZswYO8oDAAAAEGYIXEEYMWKExowZo+PHj+uBBx7wLq+srNRdd92liooK/eQnP1FcXJyNVQIAAAAIFxE7hssKU6ZM0dGjRyVJX375pSTp2Wef1VtvvSVJ6tSpk1atWuW3z5IlSzRq1CgtWrRI2dnZ+vrXv66NGzdq3759Gj16tH7xi1+07IcAAAAAELZadeDasmVLnYcUHzlyREeOHJEUePKLPn36aMuWLXrooYf01ltvadWqVeratavmzJmjX/7yl0pISGiR2gEAAACEP4cxxthdBPylpaWpoqJCvXr1srUOzzT3zXnuAEA7ghVoR7AC7QhWoB21Trm5uYqLi1NhYWHQ+zKGKwylpKSExTiwY8eO6dixY3aXgQhHO4IVaEewAu0IVqA
dtU5xcXFNDtn0cKFe/fv3lyTt2LHD5koQyWhHsALtCFagHcEKtCMEix4uAAAAAAgRAhcAAAAAhAiBCwAAAABChMAFAAAAACFC4AIAAACAEGGWQgAAAAAIEXq4AAAAACBECFwAAAAAECIELgAAAAAIEQIXAAAAAIQIgQsAAAAAQoTABQAAAAAhQuACAAAAgBAhcAEAAABAiBC44Ke0tFQPPfSQLrzwQiUmJur888/XLbfcoiNHjthdGlpYSUmJVq9erVtvvVV9+/ZVYmKiUlJSNHjwYD3yyCMqKiqqd9+lS5dqxIgRSk1NVUZGhq688krl5OQ0eL4PP/xQV155pTIyMpSamqoRI0bor3/9q9UfC2EgPz9fHTt2lMPhUO/evRvclraEQPLy8nT//ferb9++SkpKUkZGhoYNG6af/exnAbd//fXXdemll6pNmzZq06aNxo0bpzfffLPBc+zYsUPTp09Xhw4dlJSUpIEDB2rhwoVyu92h+EhoYRs3btT111+v888/X3FxcUpPT9fYsWO1ZMkSGWPqbF9VVaUnnnhCAwcOVFJSkjp06KDrr79eu3btavA8TWl7iEIGOKu0tNRcfPHFRpLp1KmTuf76682IESOMJNOhQweTm5trd4loQc8884yRZCSZfv36menTp5srrrjCpKWlGUnma1/7mvnqq6/q7Ddr1iwjySQlJZnJkyebK664wsTGxhqn02lWrVoV8FwrVqwwTqfTOBwOc+mll5qpU6ea9PR0I8ncd999If6kaGk333yzcTgcRpLp1atXvdvRlhDIxx9/bDIzM40k079/f/Od73zHTJo0yXTr1s04nc462z/xxBNGkomNjTUTJ040kydPNklJSUaS+eMf/xjwHDk5Od5tRowYYa6//nqTlZVlJJnp06cbt9sd6o+JEPL8f0KSGTZsmLn++uvNZZddZmJjY40k873vfc9v+6qqKjNlyhQjyaSnp5upU6eaSy+91DgcDpOcnGw2bNgQ8DxNaXuITgQueP33f/+3kWRGjRplCgsLvcsff/xxI8lceuml9hWHFrd06VJz++23m507d/ot//LLL83QoUONJPPd737Xb92aNWuMJJOZmWn27NnjXZ6Tk2Pi4+NNenq6KSgo8NsnPz/ftGnTxkgyr776qnf5sWPHTO/evY0k8+6771r++WCPtWvXGknm9ttvbzBw0ZYQyPHjx0379u1NcnKyee211+qsr33h+9lnnxmn02kSEhJMTk6Od/nu3btNZmamiY2NNXv37vXbp7y83PTo0cNIMgsWLPAuLywsNKNGjTKSzJIlS6z9YGgxFRUVpmPHjkaSWbZsmd+6nTt3moyMDCPJZGdne5d7/gDZp08fc+zYMe/yFStWGEmmd+/epqKiwu9YTWl7iF4ELhhjjHG5XKZt27ZGktm8eXOd9YMGDTKSzMcff2xDdQg3OTk5RpJJSEgwLpfLu3zSpElGknniiSfq7POTn/zESDLz58/3W/773//eSDKTJ0+us8/KlSuNJHP11Vdb/RFgg5KSEtOrVy/z9a9/3ezZs6fBwEVbQiB33nmnkWSefvrpoLafNWtWnXULFiwwkszdd9/tt/zll182kszgwYPr7LNp0yYjyQwYMKAp5SMMfPrpp0aS6du3b8D1nv+//P73v/cu69evn5EUsGf92muvNZLMihUr/JY3pe0hejGGC5KqxzycPn1avXr10tChQ+usnzZtmqTqe5GBwYMHS5JcLpfy8/MlVY//y87OllTTXnzV14Y897IH2ueqq65SYmKi1q5dq7KyMus+AGzxq1/9Svv27dNf/vIXxcXF1bsdbQmBlJaW6sUXX1RKSopmzpzZqH0aahNNaUfDhg1Tz549tX37dn3xxRfBlI8wkZCQ0KjtMjMzJUn79+/Xrl27lJSUpKuuuqrOdk1pR1xTtT4ELkiSPvnkE0nVv0wC8Szftm1bi9WE8LVv3z5JUlxcnDIyMiRJu3fvlsvlUocOHdSlS5c6+9TXhhpqe/Hx8RowYIDKysq0Z88eSz8DWta2bdv0+OOPa+bMmRo7dmyD29KWEMjHH3+swsJCDR06VElJSfq///s/zZ49W3fddZcWLlyoL7/80m/7U6dO6eDBg5IU8A+JF1xwgdq3b68DBw7ozJkz3uX8PoxuPXv2VK9evbR792699NJLfut27dqlF198Ue3atdOUKVMk1bSHAQMGBPxDUaD20NS2h+hF4IIkef/HEOjixnf5gQMHWqwmhK9FixZJkiZOnOj9a+G52lBKSorS09NVUFCgwsJCSdKZM2d0+vTpBvej7UU+t9ut2267Tenp6frDH/5wzu1pSwhk586dkqSOHTvq29/+tq688ko98cQT+vOf/6yf/vSn6t27t5YvX+7d3tOO2rVrp5SUlIDHDNQm+H0Y3ZxOp/73f/9X6enp+v73v6+LLrpIN9xwg8aPH69BgwapS5cueuedd7x/TGxKe2hq20P0InBBkrxTfCcnJwdc7/kfhufiBq3XP//5Tz333HOKi4vTvHnzvMvP1Yakuu3Id2p52l70+uMf/6iNGzfqscce896m0xDaEgIpKCiQJP3jH//QW2+9paefflrHjx/XF198ofvvv1+lpaW6+eabtXXrVklNa0eN2Y92FPnGjBmj9957Tz179tTmzZv18ssv691331VMTIy+9a1vqWfPnt5tm9Iemtr2EL0IXAAa7bPPPtONN94oY4wee+wx71guoD4HDx7Ugw8+qEsvvVQzZsywuxxEMM/zryorK/XII4/orrvuUocOHdStWzc99thjmj59uioqKvTYY4/ZXCnC3fLlyzVixAhdcMEF2rBhg4qKirRnzx7NmDFDjz/+uMaPHy+Xy2V3mYgiBC5IklJTUyVVP+w2kOLiYklSWlpai9WE8HLkyBFNnDhRBQUFmj17tmbNmuW3/lxtSKrbjjz7NLQfbS+y/fjHP1Z5ebn+8pe/NHof2hIC8f0ZB5o0w7Psvffe89s+mHbUmP1oR5Ft7969uvnmm9W+fXu98cYbGjFihFJSUtSnTx8tXrxYV199tTZv3qznn39eUtPaQ1PbHqIXgQuSpK5du0qSDh8+HHC9Z3m3bt1arCaEj5MnT2rChAk6cOCAZs6cqfnz59fZ5lxtqLi4WKdOnVK7du28v2DatGmjtm3bNrgfbS+yvfHGG0pOTtYdd9yhcePGeb9uuOEGSdVB3rPs2LFjkmhLCMzzc0tOTlaHDh3qrO/evbsk6fjx45Jq2lFBQYH34ra2QG2C34fR7W9/+5sqKio0ceJEvxDvcf3110uS/vOf/0hqWntoattD9CJwQVLNNN+bN28OuN6zfNCgQS1WE8JDUVGRJk2apJ07d+q6667TM888I4fDUWe7vn37KiEhQXl5eTpy5Eid9fW1oYbaXkVFhbZv367ExERdeOGFVnwc2ODUqVN67733/L42bNggSSorK/Mu80zXTltCIJ7Z3kpLSwPe7nXy5ElJNb0L6enp3gvfLVu21Nn+0KFDOnHihLp166Y2bdp4l/P7MLp5go7nDzS1eZZ7xgx62s
P27dtVUVFRZ/tA7aGpbQ/Ri8AFSdUDSNu2bavc3FzvgGNfK1askCRdc801LVwZ7ORyuTR58mR99NFHuuKKK7R8+XI5nc6A2yYlJWn8+PGSpFdeeaXO+vrakOe5Jp71vt544w2VlZXp8ssvV2JiYrM+C+xhjAn4tX//fklSr169vMs8PRS0JQTStWtXDR48WMYY722DvjzLfKfhbqhNNKUdbdmyRfv27dOAAQO87RWRJSsrS1L1YwYC2bhxo6SaHtMePXqoX79+Ki0t9T5by1dT2hHXVK2QXU9cRvj57//+byPJjB492hQVFXmXP/7440aSufTSS+0rDi2usrLSTJkyxUgyY8eONcXFxefcZ82aNUaSyczMNHv27PEuz8nJMQkJCSY9Pd0UFBT47ZOfn2/atGljJJlXX33Vu/yrr74yvXv3NpLMu+++a9XHQpjYv3+/kWR69eoVcD1tCYEsW7bMSDIDBw40X375pXf5li1bTEZGhpFk/v73v3uXf/bZZ8bpdJqEhASzbt067/I9e/aYzMxMExsba/bu3et3jvLyctOjRw8jySxYsMC7vKioyIwaNcpIMkuWLAndh0RIbdq0yUgyksyf/vQnv3Xr1q0zKSkpRpJZs2aNd/kzzzxjJJk+ffqYr776yrv81VdfNZJM7969TUVFhd+xmtL2EL0IXPAqLS01I0eONJJMp06dzPXXX+9936FDB5Obm2t3iWhBCxcu9P5SmjJlirn55psDfuXl5fntN2vWLCPJJCcnm8mTJ5tJkyaZ2NhY43Q6zapVqwKea8WKFSYmJsY4HA5z2WWXmWnTppn09HQjycyePbsFPi1a2rkClzG0JQR28803G0kmPT3dXHnlleayyy4zCQkJRpL54Q9/WGf7BQsWGEkmNjbWTJo0yUyePNkkJSUZSebJJ58MeI4PP/zQu83IkSPN9ddfbzp16mQkmWnTphm32x3qj4kQuv/++72/3/r372+mT59uxowZY2JiYowkc/vtt/ttX1VV5f0DZLt27cy0adPMuHHjjMPhMElJSWb9+vUBz9OUtofoROCCn5KSEjNnzhzTq1cvEx8fb7KyssyMGTPMoUOH7C4NLezhhx/2/kJq6Gv//v119l2yZIm56KKLTHJysklPTzcTJ040H374YYPn++CDD8zEiRNNenq6SU5ONt/4xjfM0qVLQ/TpYLfGBC5jaEuoy+12m//5n//xtouUlBQzatSoBn/G//jHP8zYsWNNamqqSU1NNWPHjjWvv/56g+fZvn27mTp1qsnMzDSJiYmmf//+ZsGCBaaqqsrqjwQbrFy50kyYMMHb29SuXTtz2WWXmZdeeing9pWVlebxxx83/fv3N4mJiSYzM9NMmzbN7Nixo8HzNKXtIfo4jDHG6tsUAQAAAABMmgEAAAAAIUPgAgAAAIAQIXABAAAAQIgQuAAAAAAgRAhcAAAAABAiBC4AAAAACBECFwAAAACECIELAAAAAEKEwAUAAAAAIULgAgAAAIAQIXABAAAAQIgQuAAAAAAgRAhcAAAAABAiBC4AAAAACBECFwAAAACECIELAAAAAEKEwAUAAAAAIfL/Af6wOGuxtDCgAAAAAElFTkSuQmCC", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAJqCAYAAAA/sUHAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAABcSAAAXEgFnn9JSAABSZUlEQVR4nO3deXxU9b3/8fdksi8QEsAgyI6UsmMBAanIzyK4UQSsbbWCWqtWi0V7bXtFqXSxFRGstuW6wK0itSJg1V4rGGvVACKLyCJgQDZBQgiQdbLM9/dHmMlMMgmZ5EzOzOT1fDzyYOasn0m+mvPO93y/x2GMMQIAAAAAWC7G7gIAAAAAIFoRuAAAAAAgRAhcAAAAABAiBC4AAAAACBECFwAAAACECIELAAAAAEKEwAUAAAAAIULgAgAAAIAQIXABAAAAQIgQuAAAAAAgRAhcAAAAABAiBC4AAAAACJFYuwtAXVlZWSouLlbXrl3tLgUAAABo9Q4ePKiUlBQdO3Ys6H3p4QpDxcXFqqiosLsMFRcXq7i42O4yEOFoR7AC7QhWoB3BCrSj1qmioqLJP3d6uMKQp2drx44dttaRnZ0tSRo/frytdSCy0Y5gBdoRrEA7ghVoR61T//79m7xvq+3hKikp0erVq3Xrrbeqb9++SkxMVEpKigYPHqxHHnlERUVFdfaZO3euHA5HvV8///nPbfgkAAAAAMJVq+3heumll/TDH/5QktSvXz9de+21OnPmjHJycvTwww9r+fLleu+999SxY8c6+44ZM0a9e/eus/yiiy4Ked0AAAAAIkerDVxxcXG6/fbbde+996pfv37e5UePHtVVV12lLVu26N5779VLL71UZ9/bbrtNM2bMaMFqAQAAAESiVntL4c0336zFixf7hS1J6tSpk55++mlJ0sqVK1VeXm5HeQAAAACiQKsNXA0ZPHiwJMnlcik/P9/magAAAABEqlZ7S2FD9u3bJ6n6tsOMjIw667Ozs7V161aVlZWpS5cumjRpEuO3AAAAANRB4Apg0aJFkqSJEycqISGhzvoXXnjB7/2cOXM0depULV26VKmpqY0+T33TS+bm5iorK8s77ahdPM8asLsORDbaEaxAO4IVaEewAu2odSouLlZKSkqT9uWWwlr++c9/6rnnnlNcXJzmzZvnt653796aP3++duzYoaKiIh06dEjLli1T586d9eqrr+qmm26yqWoAAAAA4chhjDF2FxEuPvvsM40ePVoFBQVauHChZs2a1aj9jh49qoEDByo/P1/r1q3TxRdf3Kw6PD1fPPgY0YB2BCvQjmAF2hGsQDtqnZpzfU4P11lHjhzRxIkTVVBQoNmzZzc6bEnVMxvOnDlTkvTWW2+FqkQAAAAAEYbAJenkyZOaMGGCDhw4oJkzZ2r+/PlBH6NPnz6Sqnu7AAAAAEAicKmoqEiTJk3Szp07dd111+mZZ56Rw+EI+jgFBQWS1OTBdAAAAACiT6sOXC6XS5MnT9ZHH32kK664QsuXL5fT6Qz6OMYYrVq1SpI0bNgwq8sEAAAAEKFabeCqqqrSd7/7XWVnZ2vs2LFauXKl4uPj690+Ly9PTz/9tAoLC/2WFxUV6c4779SGDRuUlZWl6667LtSlAwAAAIgQrfY5XE899ZS3V6p9+/a66667Am43f/58tW/fXsXFxbr77rv185//XMOHD1enTp2Ul5enzZs3Kz8/X+np6VqxYoWSk5Nb8mMAAAAACGOtNnB5xlxJ8gavQObOnav27dsrMzNTDzzwgNavX689e/YoJydHTqdTPXr00IwZM/TTn/5UnTt3bonSAQAAAESIVhu45s6dq7lz5zZ6+7S0ND366KOhKwgAAABA1Gm1Y7gAAAAAINQIXAAAAAAQIgQuAAAAAAgRAhcAAAAAhEirnTQDAAAAQPja+1
Whdh49I0nqnJ6kb3TPsLmipiFwAQAAAAg773x2XI/+32eSpKsGdorYwMUthQAAAADCjtuYmjcO++poLgIXAAAAgLDjm7diHJGbuAhcAAAAAMKO8UlcMZGbtwhcAAAAAMKPOzruKCRwAQAAAAg/3FIIAAAAACHiO2mGg8AFAAAAANYxfoHLxkKaicAFAAAAIOy4/W4ptK+O5iJwAQAAAAg7Rr6zFEZu4iJwAQAAAAg7frMUErgAAAAAwDpuxnABAAAAQIgwhgsAAAAAQsO3h4sxXAAAAABgIb8xXPaV0WwELgAAAABhxzBpBgAAAACEBrcUAgAAAECIGGYpBAAAAIDQ8LmjkFkKAQAAAMBK3FIIAAAAACHiO0thJE9TSOACAAAAEHaM34OPIzdxEbgAAAAAhB3jd0uhjYU0E4ELAAAAQNjxHcPliOB7CglcAAAAAMKO/y2F9tXRXAQuAAAAAGHHd9IMB2O4AAAAAMA6PPgYAAAAAELE/8HHkZu4CFwAAAAAwo6bWQoBAAAAIDQYwwUAAAAAIcIYLgAAAAAIEf9p4SM3cRG4AAAAAIQd/wcfRy4CFwAAAICwQw8XAAAAAISImzFcAAAAABAazFIIAAAAACHDc7gAAAAAICTcjOECAAAAgNBgDBcAAAAAhIhhDBcAAAAAhIZvDxdjuAAAAADAQn49XBH86GMCFwAAAICwQw8XAAAAAISIYZZCAAAAAAgNt/89hRGLwAUAAAAg7NDDBQAAAAAhYsQYLgAAAAAICTc9XAAAAAAQGr5juCI4bxG4AAAAAIQfvzkzIjhxEbgAAAAAhB3Dc7gAAAAAIDTcfrPCR27iInABAAAACDvMUggAAAAAIeJ217xmDBcAAAAAWIhZCiNcSUmJVq9erVtvvVV9+/ZVYmKiUlJSNHjwYD3yyCMqKiqqd9+lS5dqxIgRSk1NVUZGhq688krl5OS0YPUAAABA68FzuCLQSy+9pClTpuj555+X0+nUtddeq7Fjx2r//v16+OGHNXz4cB0/frzOfvfee69mzpyp7du36/LLL9eIESO0Zs0affOb39Tq1atb/oMAAAAAUcjNLIWRLS4uTrfffrt27typnTt36u9//7veeust7d69W0OHDtVnn32me++912+ftWvXatGiRcrMzNQnn3yi1atX66233tJ//vMfOZ1OzZw5U6dOnbLl8wAAAADRxG+WQgJX5Ln55pu1ePFi9evXz295p06d9PTTT0uSVq5cqfLycu+6BQsWSJIefPBB9enTx7t81KhRuuOOO3Tq1Ck999xzLVA9AAAAEN2M3xiuyE1crTZwNWTw4MGSJJfLpfz8fElSaWmpsrOzJUnTpk2rs49n2euvv95CVQIAAADRyydvMYYr2uzbt09S9W2HGRkZkqTdu3fL5XKpQ4cO6tKlS519hg0bJknatm1byxUKAAAARCm/WQptrKO5Yu0uIBwtWrRIkjRx4kQlJCRIkg4ePChJAcOWJKWkpCg9PV0FBQUqLCxUWlraOc/Tv3//gMtzc3OVlZXl7VGzS3FxsSTZXgciG+0IVqAdwQq0I1iBdtRySkrKvK8/+WSrKg87bauluLhYKSkpTdqXHq5a/vnPf+q5555TXFyc5s2b513umSY+OTm53n09P4TCwsLQFgkAAABEOZ/nHkd0aKGHy8dnn32mG2+8UcYYPfbYY96xXKGyY8eOgMs9PV/jx48P6fnPxfOXG7vrQGSjHcEKtCNYgXYEK9COWk7CR9lSaakkaeiwoRrdq71ttTS1d0uK7LBoqSNHjmjixIkqKCjQ7NmzNWvWLL/1qampkqofmFwfTxdzY24nBAAAANA4TJoR4U6ePKkJEybowIEDmjlzpubPn19nm65du0qSDh8+HPAYxcXFOnXqlNq1a0fgAgAAAJrJ/8HHBK6IVVRUpEmTJmnnzp267rrr9MwzzwSc579v375KSEhQXl6ejhw5Umf95s2bJUmDBg0Kec0AAABAtPObpTBy81brDlwul0uTJ0/WRx99pCuuuELLly+X0xl49pOkpCTvvbqvvPJKnfUrVqyQJF1zzTWhKxgAAABoJfyfw2VfHc3VagNXVVWVvvvd7yo7O1tjx47VypUrFR8f3+A+s2fPliT9+te/1t69e73L161bp8WLFys9PV233nprSOsGAAAAWgO3T+AKdAdapGi1sxQ+9dRTWrVqlSSpffv2uuuuuwJuN3/+fLVvXz0jyuWXX65Zs2Zp0aJFGjJkiL71rW+pvLxca9askTFGS5YsUXp6ekt9BAAAACBqGR58HNkKCgq8rz3BK5C5c+d6A5ckLVy4UEOGDNFTTz2lNWvWKD4+XpdffrnmzJmj0aNHh7RmAAAAoLXw6eCK6EkzWm3gmjt3rubOndukfWfMmKEZM2ZYWg8AAACAGsxSCAAAAAAh4nYzSyEAAAAAhITxmzTDvjqai8AFAAAAIOxEyxguAhcAAACAsMODjwEAAAAgRJg0AwAAAABCxHcMV0zk5i0CFwAAAIDw4xu4IvnRxwQuAAAAAGHH/5ZCGwtpJgIXAAAAgLDDLIUAAAAAECJMmgEAAAAAIcKDjwEAAAAgBIz/jBkELgAAAACwits/b3FLIQAAAABYxU0PFwAAAACEhqGHCwAAAABCgx4uAAAAAAiR2j1cDkVu4iJwAQAAAAgrRv6JKyZy8xaBCwAAAEB4YZZCAAAAAAgRxnABAAAAQIjUGcMVwYmLwAUAAAAgrBjDGC4AAAAACInaY7jo4QIAAAAAi9DDBQAAAAAhwiyFAAAAABAitXu4IhmBCwAAAEBYqR236OECAAAAAIvUfg4XY7gAAAAAwCLMUggAAAAAIcIshQAAAAAQIrXnzKCHCwAAAAAs4juGK4KzliQCFwAAAIAw4zuGK5JnKJQIXAAAAADCjO8YrkgevyURuAAAAACEGd8eLociO3ERuAAAAACEFcMYLgAAAAAIDd9JChnDBQAAAAAWYpZCAAAAAAgRt7vmNT1cAAAAAGAhI3q4AAAAACAkjN8shZGNwAUAAAAgrPiO4YqJ8AdxEbgAAAAAhBXfHi7GcAEAAACAhfx6uCI7bxG4AAAAAIQXt++DuCJ8FBeBCwAAAECYoYcLAAAAAELCzRguAAAAAAgNt5vncAEAAABASPgO4aKHCwAAAAAs5DtLYYTnLQIXAAAAgPDi+xwuAhcAAAAAWIgHHwMAAABAiPg/+JjABQAAAACW8RvDZWMdViBwAQAAAAgrvrMURngHF4ELAAAAQHgx3FIIAAAAAKHhdte8jvC8ReACAAAAEF548DEAAAAAhIj/g48JXBFr06ZNevTRR3XdddepS5cucjgcDf5A586d690m0NfPf/7zFqweAAAAiE4mimYpjLW7ADvNmzdPr732WtD7jRkzRr17966z/KKLLrKiLAAAAKBV83vwcYR3EbXqwDVq1CgNGjRIw4cP1/Dhw9W9e3e5XK5z7nfbbbdpxowZoS8QAAAAaIXcvoErwm8pbNWB64EHHrC7BAAAAAC18OBjAAAAAAgR/wcfR3bkatU9XE2VnZ2trVu3q
qysTF26dNGkSZMYvwUAAABYxP/BxzYWYgGH8f00rVxiYqJcLpfq+5bMnTtXv/rVrwKumzp1qpYuXarU1NRGn69///4Bl+fm5iorK0vPP/98o48VCsXFxZKklJQUW+tAZKMdwQq0I1iBdgQr0I5axrovK/XM9gpJUu/0GP1yRIKt9dxyyy1KSUnRjh07gt6XWwqD0Lt3b82fP187duxQUVGRDh06pGXLlqlz58569dVXddNNN9ldIgAAABDx3D6vI7yDix4uX+fq4arP0aNHNXDgQOXn52vdunW6+OKLm1WHp+erKQnaStnZ2ZKk8ePH21oHIhvtCFagHcEKtCNYgXbUMlZsOqz7X/lEkjSiR4b+/qNRttbTnOtzergs0KlTJ82cOVOS9NZbb9lcDQAAABDZmKUQdfTp00dSdW8XAAAAgKbznzQjsiMXgcsiBQUFkhhACQAAADSX7wifmAhPLBFefngwxmjVqlWSpGHDhtlcDQAAABDZ3D6ByxHhNxUSuBopLy9PTz/9tAoLC/2WFxUV6c4779SGDRuUlZWl6667zqYKAQAAgOjgN4YrsvNW637w8Ztvvql58+Z535eXl0uS3yyDc+bM0VVXXaXi4mLdfffd+vnPf67hw4erU6dOysvL0+bNm5Wfn6/09HStWLFCycnJLf45AAAAgGjiO2d4pI/hatWBKy8vTxs2bKiz3HdZXl6eJCkzM1MPPPCA1q9frz179ignJ0dOp1M9evTQjBkz9NOf/lSdO3dusdoBAACAaGXo4YoOM2bM0IwZMxq1bVpamh599NHQFgQAAABAbjezFAIAAABASPjfUmhbGZYgcAEAAAAIK36zFNLDBQAAAADW8RvDZWMdViBwAQAAAAgrfg8+pocLAAAAAKzj+xyumAhPLBFePgAAAIBo4zeGK8JvKiRwAQAAAAgrRtHzHC4CFwAAAICwwhguAAAAAAgR3wcfR3jeInABAAAACC/+Dz6O7MRF4AIAAAAQVnxnKYzwvEXgAgAAABBemKUQAAAAAELF9zlckZ23CFwAAAAAwoubWQoBAAAAIDQYwwUAAAAAIeI7S6EjwhMXgQsAAABAWHEzhgsAAAAAQsP4zlJI4AIAAAAA6xi/Hq7ITlwELgAAAABhhVkKAQAAACBEfMdwRToCFwAAAICwYujhAgAAAIDQMMxSCAAAAACh4WaWQgAAAAAIDTezFAIAAABAaPhOmeEgcAEAAACAdXzHcEV43iJwAQAAAAgvbnfNaybNAAAAAAALGTGGCwAAAABCwm+WQvvKsASBCwAAAEBYcfuN4YrsyEXgAgAAABBefHq4uKUQAAAAACzkZpZCAAAAAAgNt18Pl311WIHABQAAACCs8OBjAAAAAAgRbikEAAAAgBAxhudwAQAAAEBIGMZwAQAAAEBouOnhAgAAAIDQ8J2lMNIRuAAAAACEFcODjwEAAAAgNPwnzbCxEAsQuAAAAACEFf9p4SM7cRG4AAAAAIQV3yFc9HABAAAAgIV8J82ghwsAAAAALGT8bim0sRALELgAAAAAhBVmKQQAAACAEHEzSyEAAAAAhIbfLIWK7MRF4AIAAAAQVozfpBn21WEFAhcAAACAsMIYLgAAAAAIETezFAIAAABAaPg/+DiyExeBCwAAAEBYoYcLAAAAAELE7TdpRmQnLgIXAAAAgLBieA4XAAAAAIQGsxQCAAAAQIj4P/g4shG4AAAAAIQVxnBFiU2bNunRRx/Vddddpy5dusjhcDTqB7p06VKNGDFCqampysjI0JVXXqmcnJwWqBgAAACIftE0hivW7gLsNG/ePL322mtB7XPvvfdq0aJFSkpK0oQJE1RWVqY1a9bo7bff1ooVK/Ttb387NMUCAAAArYSJoh6uVh24Ro0apUGDBmn48OEaPny4unfvLpfLVe/2a9eu1aJFi5SZmal169apT58+kqR169Zp3LhxmjlzpsaNG6f09PQW+gQAAABA9HHTwxUdHnjggaC2X7BggSTpwQcf9IYtqTq43XHHHXryySf13HPP6b777rO0TgAAAKA18engYpbC1qK0tFTZ2dmSpGnTptVZ71n2+uuvt2hdAAAAQLRx+91TaF8dViBwNdLu3bvlcrnUoUMHdenSpc76YcOGSZK2bdvW0qUBAAAAUSWansPVqm8pDMbBgwclKWDYkqSUlBSlp6eroKBAhYWFSktLO+cx+/fvH3B5bm6usrKyvD1qdikuLpYk2+tAZKMdwQq0I1iBdgQr0I5aRnFJmff1tk+2quqw08Zqqn/uKSkpTdqXHq5GKioqkiQlJyfXu43nh1BYWNgiNQEAAADRyLeHK9LRw2WjHTt2BFzu6fkaP358S5ZTh+cvN3bXgchGO4IVaEewAu0IVqAdtYyEjdlSaakk6aKhQzW6d3tb62lq75ZED1ejpaamSpJKSkrq3cbTxdyY2wkBAAAABBZNz+EicDVS165dJUmHDx8OuL64uFinTp1Su3btCFwAAABAMxifxBXheYvA1Vh9+/ZVQkKC8vLydOTIkTrrN2/eLEkaNGhQS5cGAAAARBV3FM1SSOBqpKSkJO+9uq+88kqd9StWrJAkXXPNNS1aFwAAABBtjM+jj2MiO28RuIIxe/ZsSdKvf/1r7d2717t83bp1Wrx4sdLT03XrrbfaVR4AAAAQFdxRNIarVc9S+Oabb2revHne9+Xl5ZKkiy++2Ltszpw5uuqqqyRJl19+uWbNmqVFixZpyJAh+ta3vqXy8nKtWbNGxhgtWbJE6enpLfoZAAAAgGgTTWO4WnXgysvL04YNG+os912Wl5fnt27hwoUaMmSInnrqKa1Zs0bx8fG6/PLLNWfOHI0ePTrkNQMAAADRzkTRGK4WC1y7du3Sjh07dMEFF2jkyJEtddoGzZgxQzNmzGix/QAAAACcm9swhiugl19+WePHj6/Ta/Szn/1MAwYM0He+8x2NHj1aU6ZMUVVVlZWnBgAAABAlKt2+gSuyE5elgevFF1/U1q1bNXToUO+ynJwcPf7440pLS9MNN9yg7t276x//+IeWLVtm5akBAAAARIEqt1GRq9L7vk1inI3VNJ+lgWv79u0aNGiQ4uPjvcteeOEFORwO/f3vf9eyZcu0ceNGpaam6tlnn7Xy1AAAAACiQGFZhd8YrrZJBC6v48ePq3Pnzn7L3n33XXXs2FETJkyQJGVkZOib3/ymPv/8cytPDQAAACAKnC6t8L52OKS0xMie58/SwJWUlKQzZ8543x89elR79uzRpZde6rddenq6CgoKrDw1AAAAgChwqqQmcLVJjFNMhM+aYWng6tmzp95//32dOnVKkrRs2TI5HA5v75bHsWPH1LFjRytPDQAAACAK+PZwpSdH9u2EksWBa8aMGTpz5owuuugiTZ06VQ8++KBSU1M1efJk7zYVFRX6+OOPdeGFF1p5agAAAABR4JRP4Ir08VuSxc/h+uEPf6h3331Xr776qvbv36+UlBQtXrxYmZmZ3m3eeOMNnT59WuPHj7fy1AAAAACiwGkCV/3i4uL0yiuv6IsvvlBeXp6+9rWvKS0t
zW+bHj16aNWqVbr44outPDUAAACAKHC6pNz7Oj05voEtI0NIpvzo3r27unfvHnDdkCFDNGTIkFCcFgAAAECE++qMy/u6bVJkz1AohShwSdLJkye1adMmnThxQt26ddPo0aNDdSoAAAAAUWDTgZN6Yf0B7/v0pMjv4bJ00gxJysvL0/e+9z1lZWVp4sSJuvHGG/0ecvzss88qIyNDH3zwgdWnBgAAABChyiqqNHPJRr9l3dun2FSNdSwNXCdPntTo0aP1t7/9TQMGDNBdd90l4/uYaEnXXXedCgsLtWLFCitPDQAAACCCfXWmTGfKKr3vZ4zurqsHdbKxImtYGrh+85vfKDc3Vw899JA2b96sP/7xj3W2ycjI0KBBg/Tee+9ZeWoAAAAAEay80u19HeOQ5l7bX4lxThsrsoalgWv16tW68MILNXfu3Aa369Wrl44cOWLlqQEAAABEMJdP4EpJiPzJMjwsDVxHjhzR4MGDz7mdw+HQmTNnrDw1AAAAgAhWXlUTuBJiLZ9qwjaWfpI2bdro6NGj59wuNzdXHTp0sPLUAAAAACKYq6ImcMU7CVwBDR8+XBs3btT+/fvr3eaTTz7R1q1bNWbMGCtPDQAAACCC+fVwRcHYLQ9LA9c999wjl8ulKVOmaNeuXXXWf/7557rppptkjNHdd99t5akBAAAARDDfSTPo4arHxIkT9V//9V/atm2bBgwYoK997WtyOBz617/+pcGDB6tfv37avn27fvnLX+qSSy6x8tQAAAAAIpirssr7Op4xXPV79NFH9fLLL2vgwIHas2ePjDE6evSoPv30U/Xp00fLli3TvHnzrD4tAAAAgAjm28MVTZNmWDrf4pkzZ+RwODR9+nRNnz5deXl5+uKLL+R2u9WlSxd17tzZytMBAAAAiBJ+txQSuAJLT0/XyJEjtW7dOklShw4dmI0QAAAAwDm5ojRwWfpJ2rZtq549e1p5SAAAAACtQLTeUmjpJxk6dKhyc3OtPCQAAACAVsB3Wvj4WKaFD+iBBx7Qxo0btWLFCisPCwAAACDKuSp8ZimMomnhLR3DlZSUpNtuu03f+c53dPXVV+uaa65R165dlZiYGHD7b37zm1aeHgAAAECEcvk9+JjAFdC4cePkcDhkjNHrr7+uN954o8Htq6qqGlwPAAAAoHWI1gcfWxq4fvCDH8jhcFh5SAAAAACtQLROmmFp4Fq6dKmVhwMAAADQSriiNHBFzycBAAAAELF48HGQysvLtXXrVh05ckSS1LlzZw0ZMkTx8fGhOiUAAACACEXgaqSysjI99NBDWrx4sYqKivzWpaam6o477tCvfvWremcuBAAAAND6uCprJtRLiKLncFkauFwuly6//HKtW7dOkjRo0CB1795dDodDX3zxhT755BPNnz9fH374od555x0lJCRYeXoAAAAAEcr/wcfR08Nl6Sd54oknlJOTozFjxmjr1q3asmWLVq1apZUrV2rz5s365JNPNHbsWK1bt04LFy608tQAAAAAIli0Tgtv6SdZvny5OnTooDfffFMDBw6ss37AgAF644031L59ey1btszKUwMAAACIYK4oHcNl6Sf5/PPPNW7cOKWlpdW7TWpqqsaNG6fc3FwrTw0AAAAggkXrc7gs/SSxsbEqKSk553YlJSWKjQ3ZBIkAAAAAIkxphc+kGXHRM2mGpYFr4MCBys7O1r59++rdZv/+/crOztagQYOsPDUAAACACGWM0fEzLu/79qnR8ygpSwPXj370I5WWlmrcuHF67rnnVFpa6l1XWlqqJUuWaNy4cSorK9Mdd9xh5akBAAAARKhCV6VfD9d5baLnEVKW3td300036YMPPtAzzzyj22+/Xbfffrvat28vSTpx4oSk6vT6ox/9SN///vetPDUAAACACHX8TJn3dZzToYxkerjqtXjxYr3yyiu65JJLFBcXp7y8POXl5SkuLk5jx47VK6+8oj//+c9WnxYAAABAhDp2uuZ2wo5piYqJcdhYjbVCMnPF1KlTNXXqVFVWVio/P1+SlJmZyUQZAAAAAOr4yqeHq2ObBBsrsV5IE1BsbKzOO++8UJ4CAAAAQIT7qrAmcJ2XFj3jt6QQPIfrySef1Pbt2+vdZvv27XryyScbnMkQAAAAQOvhO0PheVHWw2Vp4Fq4cKHuu+8+tWnTpt5t0tLSNHv2bD355JNWnhoAAABAhCotr5mhMDUxuoYhWRq43nnnHQ0ZMkRdu3atd5tu3bppyJAhWrNmjZWnBgAAABChXJU+Dz2OjZ6HHksWB65Dhw6pZ8+e59yuV69eOnz4sJWnBgAAABChXJVu7+uEWMsnUreVpZ/G6XTK5XKdczuXy6WqqqpzbgcAAAAg+pX5PPQ4MY4ernpdeOGF+uCDD1RSUlLvNiUlJfrggw/Up08fK08NAAAAIELRw9VI06ZN08mTJ3XbbbepuLi4zvqSkhL98Ic/VEFBgaZNm2blqQEAAABEKL/AFRddgcvSKUDuuecevfDCC3r55Zf17rvv6rvf/a569eolScrNzdXy5ct1/Phx9e3bV/fee6+VpwYAAAAQoaJ50gxLA1dycrLWrl2rG2+8UdnZ2Vq4cKEcDockyRgjSbrsssv0wgsvKCUlxcpTAwAAAIhQZRU1PVyJ9HA1LCsrS2vXrtXGjRu1du1aHTp0SJJ0wQUX6PLLL9fw4cOtPiUAAACACEYPVxMMHz7cG67Kysp06tQptW/fPlSnAwAAABChXBVMmhFQYWGhPvroI+3Zsyfg+r179+qKK65QmzZt1LlzZ6WlpemGG27Q0aNHm3NaAAAAAFGEaeHr8de//lWjRo3SG2+8UWfdsWPHNHbsWK1du1aVlZUyxsjlcumVV17RZZddptLS0uacGgAAAECUYFr4erz33nuKiYnRjTfeWGfdvHnzdPz4cbVr106vvfaaCgsLtXnzZn3jG9/Q3r179ec//7k5pwYAAAAQBYwxtQIXPVxe27Zt08CBA9WxY0e/5W63W8uXL5fD4dBvf/tbXXPNNUpJSdGQIUO0cuVKxcbGavXq1c05NQAAAIAo4Bu2pOibpbBZn8bzTK3atm3bplOnTik2NlY33HCD37rOnTtr5MiR2rVrV3NODQAAACAK1A5c9HD5KC4uVkVFRZ3lmzZtkiQNGjRIbdq0qbO+S5cuOnPmTHNObatx48bJ4XDU+/XWW2/ZXSIAAAAQEXynhJekhCjr4WrWtPAdOnTQZ599Vmf5Bx98IIfDoREjRgTcr6ysTG3btm3OqcPC1KlTlZqaWmd5586dbagGAAAAiDy+U8JL0TdpRrMC18iRI7V69Wq9/vrruuaaayRJeXl5WrlypSRpwoQJAffbuXOnzj///OacOizMnz9f3bt3t7sMAAAAIGL59nDFx8bI4XDYWI31mhUf7777bhljNH36dP3gBz/Qfffdp+HDh6uwsFAXXHCBrrrqqjr77Nu3T3v27NHgwYObc2oAAAAAUaAsih96LDWzh+uyyy7T3Llz9atf/UovvviiHA6HjDFKSkrSkiVLFBtb9/B/+ctfJElXXHFFc04
NAAAAIAr4TpoRbQ89lpoZuCTpoYce0tVXX62VK1cqLy9PF1xwgb7//e+rR48eAbdPSEjQrFmzNHHixOae2nbPPfec8vPzFRMTowsvvFDf/va31bVrV7vLAgAAACKGq6LmlsJo7OFyGGOM3UVEmnHjxum9996rszwuLk5z5szRnDlzGnWc/v37B1yem5urrKwsPf/8882qs7mKi4slSSkpKbbWgchGO4IVaEewAu0IVqAdWW9bXpUWbimXJHVKceg3YxJtrqiuW265RSkpKdqxY0fQ+0ZfhGwB3/zmN/XCCy8oNzdXJSUl2r17t37zm98oNjZWDz30kBYtWmR3iQAAAEBE8J2kMMpmhJdED5el3n77bV1xxRVKT0/Xl19+qaSkpCYdx9Pz1ZQEbaXs7GxJ0vjx422tA5GNdgQr0I5gBdoRrEA7st7qLUd078tbJUnDuqZr5V1j7C0ogOZcn0dhhrTPhAkT9I1vfEOnTp3Shg0b7C4HAAAACHu+08InxEbfpBkELov16dNHknT06FGbKwEAAADCn+8shQlReE9h9H0imxUUFEhiICUAAADQGGU+sxQm0sOFhuTl5en999+XJA0bNszmagAAAIDw56qghws+cnJytHr1alVVVfkt/+KLLzRlyhQVFxfr2muvVZcuXWyqEAAAAIgcfrcURuFzuJr94OPWZs+ePZo5c6aysrI0bNgwpaen68CBA9q0aZPKysrUv39/PfPMM3aXCQAAAEQEv1sK46LvlkICV5BGjhypO++8Uxs2bNDGjRtVUFCglJQUDRkyRNOnT9edd97Z5OngAQAAgNaGHi746devn/70pz/ZXQYAAAAQFZgWHgAAAABCxLeHK5FJMwAAAADAOr5juOjhAgAAAAAL8eBjAAAAAAgR3+dw8eBjAAAAALBQme+kGfRwAQAAAIB1fHu4onFa+Oj7RAAAAAAiht+08FH44GMCFwAAAADblNHDBQAAAACh4TdLIZNmAAAAAIB1/G4ppIcLAAAAAKzjNy08Y7gAAAAAwBput1F5FWO4AAAAAMByvs/gkngOFwAAAABY5ujpMu9rZ4xDGcnxNlYTGgQuAAAAALY4XFDqfZ3VJlGxzuiLJ9H3iQAAAABEhCM+gatLuyQbKwkdAhcAAAAAWxwuKPG+7kzgAgAAAADrHDnl28OVbGMloUPgAgAAAGAL3zFcXdLp4QIAAAAAyzCGCwAAAABCoLzSra8Ka6aFZwwXAAAAAFjk6OlSGVP92uGQOrUlcAEAAACAJXzHb52Xlqj42OiMJtH5qQAAAACEtdYwfksicAEAAACwwWGfKeGjdfyWROACAAAAYIOvTtdMmJHVNtHGSkKLwAUAAACgxfnOUHheGoELAAAAACxzjB4uAAAAAAiN44Uu7+vz2iTYWEloEbgAAAAAtChXZZVOFpd735/Xhh4uAAAAALDE5gOn/N53ZAwXAAAAADRfldto1t+2eN+3T42P2oceSwQuAAAAAC3o37uP+43fmnbRBTZWE3oELgAAAAAt4kxZhR5cvd37/qJu7fTAxL42VhR6BC4AAAAALWL1liM66jMd/H0TLpTD4bCxotAjcAEAAABoEZ8dK/S+vrzfeRrdq72N1bQMAhcAAACAFvH58SLv60svjP6wJRG4AAAAALSQXJ/A1atjqo2VtBwCFwAAAICQKyguV77Pw457E7gAAAAAwBqf59X0brVJjFWH1AQbq2k5BC4AAAAAIec7fqt3x9Son53Qg8AFAAAAIORqB67WgsAFAAAAIORy8whcAAAAAGA5Y4x2fnnG+57ABQAAAAAW2Xn0jI4XuiRJDoc0uEu6vQW1IAIXAAAAgJBau/O49/XQC9KV2UpmKJQIXAAAAABCyO02enXzYe/7/9fvPBuraXkELgAAAAAhs25fvg6eLJEkxTikKUM721xRyyJwAQAAAAiJY6fL9P1nN3jfX3phB52fnmRjRS2PwAUAAADAcm630e0vfOy37IYRXW2qxj6xdhcAAAAAIHpUVrn17915+uO7n2vb4dPe5d0zkzX+ax1trMweBC4AAAAAzfbZsTNatfmI/rn9qA6dLPVblxzv1Ct3jFacs/XdYEfgAgAAANAsr209ontf3ipj6q4b1TNTT39/mDJS4lu+sDBA4AIAAADQLM++v79O2Bp8QbpuHtVNU4Z2lsPhsKewMEDgAgAAANAk+08U65HXd+jTIzVjtWaM7q6bRnVTrw6pNlYWPghcAAAAAILmdhv96IWPteerIu+yvuelae61/W2sKvwQuAAAAAA0Wkl5pdbvy9eitXv9wlac06FHJhO2aiNwAQAAAGiQMUafHjmttbuOa+mH+3WmrNJvfYxDev+/xiurbaJNFYYvAhcAAAAAP4cLSvTVGZfOlFboVGm53tp+TP/a8VXAbQd2bqvnZnxDHdMIW4EQuJqgtLRUv/vd7/S3v/1NBw8eVEZGhiZOnKh58+apc+fOdpcHAAAAnFNFlVsnilw6XVqh/KJy7TtRrJzPT2jjFyd1oqj8nPtf3u883Ty6my7p3b5Vz0J4LgSuIJWVlWn8+PFav369OnXqpMmTJ+uLL77QkiVL9MYbb2j9+vXq2bOn3WUCAACgFahyGxWWVaiwrFKnSyt0pqxCJa4qVVS5VV7lVllFlU6XVuhkcYWOF5Ypr9ClvEKXjhe6dLL43KGqtm6ZyRrbp71uGdNDPZmFsFEIXEH69a9/rfXr12vUqFF6++23lZpa3dAWLFig++67T7fccov+/e9/21skAAAAwp4xRq5Kt86UVqigpEIni8t1qqRcJ0vKVeKqUmlFlcoqqlTkqtSZ0upQdeZsuDpTWqEzZZUqclWe+0RNlJESr/TkOLVNilNmSoImDsjSlKGd5YyhNysYBK4glJeX66mnnpIkPf30096wJUmzZ8/W//7v/+q9997Tpk2bdNFFF9lVJgAAAIJUWeVWWWV1j5Dr7L/VX265KqvkqqhetvnLSlW4pSPrD6iyyq2KKreKXFUqLa9URZVRpdutyirj97q8yq3KKrcq3UZlFVXKLyr39kZVVJlzFxdiSXFOtU2KU9fMZPXqkKKLe2ZqWNd2uiAj2e7SogKBKwgffvihTp8+rV69emno0KF11k+bNk3btm3T66+/TuACAABRxRijKrdRlTFyu6Wqs+/d3mVGlW6jyrNBo+rs+yq3kTE+29far8p9NpycvQXO89p9dh+3u9Z+Rt7XnuXV76uXV1R5Ak/t8ORWmSc4+QQozzaV7iCDz87toflGN0NaYqzaJMYpJcGpOGeMYp0xSoyNUXpynNKT4tUhLUEd2ySoY1pC9eu0RHVIS1BinNPu0qMagSsIn3zyiSRp2LBhAdd7lm/btq3FagKAYBlTffFjPK8lmbPXGUam5rWpee+7rWed6lnv2U/eY9ZzLJ/z+J67bl1GR4rckqQ9XxX6H8v41+x7PM8x/c9R8z2o2ab2foH38Vt29kXdfWvOqTrnbHxdfpd9pp59TJ1NGqjL5/PWU5dpcJuac/t/733aRZ2fsf97zwG86/zaSz3nqPXeU0u956jVBm
[base64-encoded PNG image output from the notebook omitted]", "text/plain": [ "
" ] diff --git a/examples/wikitext/README.md b/examples/wikitext/README.md index 635c7bd..9acccd0 100644 --- a/examples/wikitext/README.md +++ b/examples/wikitext/README.md @@ -25,7 +25,7 @@ This will fine-tune the model using the specified hyperparameters and save the f ## Computing Pairwise Influence Scores -To compute pairwise influence scores using the `ekfac` factorization strategy, run the following command: +To compute pairwise influence scores using the `ekfac` strategy, run the following command: ```bash python analyze.py --query_batch_size 32 \ @@ -34,26 +34,26 @@ python analyze.py --query_batch_size 32 \ --factor_strategy ekfac ``` -You can also use `identity`, `diagonal`, and `kfac` for `factor_strategy`. On an A100 (80GB) GPU, this process takes approximately 50 minutes. +You can also use `identity`, `diagonal`, and `kfac` for `factor_strategy`. On an A100 (80GB) GPU, this process takes approximately 40 minutes: ``` ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | ---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 2790.6 | 100 % | +| Total | - | 11 | 2357.4 | 100 % | +---------------------------------------------------------------------------------------------------------------------------------- +| Compute Pairwise Score | 1888.2 | 1 | 1888.2 | 80.098 | +| Fit Lambda | 274.64 | 1 | 274.64 | 11.651 | +| Fit Covariance | 180.27 | 1 | 180.27 | 7.6471 | +| Perform Eigendecomposition | 7.7754 | 1 | 7.7754 | 0.32984 | +| Save Eigendecomposition | 3.0652 | 1 | 3.0652 | 0.13003 | +| Save Covariance | 2.6799 | 1 | 2.6799 | 0.11368 | +| Save Lambda | 0.66036 | 1 | 0.66036 | 0.028013 | +| Load Covariance | 0.033343 | 1 | 0.033343 | 0.0014144 | +| Save Pairwise Score | 0.016471 | 1 | 0.016471 | 0.0006987 | +| Load All Factors | 0.0086084 | 1 | 0.0086084 | 0.00036517 | +| Load Eigendecomposition | 0.0054964 | 1 | 0.0054964 | 0.00023316 | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 2253.1 | 1 | 2253.1 | 80.739 | -| Fit Lambda | 292.74 | 1 | 292.74 | 10.49 | -| Fit Covariance | 194.38 | 1 | 194.38 | 6.9654 | -| Perform Eigendecomposition | 22.295 | 1 | 22.295 | 0.79893 | -| Save Covariance | 12.157 | 1 | 12.157 | 0.43564 | -| Save Eigendecomposition | 11.641 | 1 | 11.641 | 0.41716 | -| Save Lambda | 3.0458 | 1 | 3.0458 | 0.10915 | -| Load Covariance | 0.48773 | 1 | 0.48773 | 0.017478 | -| Load Eigendecomposition | 0.45834 | 1 | 0.45834 | 0.016425 | -| Load All Factors | 0.18377 | 1 | 0.18377 | 0.0065855 | -| Save Pairwise Score | 0.10407 | 1 | 0.10407 | 0.0037292 | -----------------------------------------------------------------------------------------------------------------------------------" ``` For more efficient computation, use half precision: @@ -66,66 +66,35 @@ python analyze.py --query_batch_size 32 \ --use_half_precision ``` -This reduces computation time to about 20 minutes on an A100 (80GB) GPU: +This reduces computation time to about 15 minutes on an A100 (80GB) GPU: ``` ---------------------------------------------------------------------------------------------------------------------------------- | Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | 
---------------------------------------------------------------------------------------------------------------------------------- -| Total | - | 11 | 1211.8 | 100 % | +| Total | - | 11 | 785.92 | 100 % | ---------------------------------------------------------------------------------------------------------------------------------- -| Compute Pairwise Score | 1034.5 | 1 | 1034.5 | 85.368 | -| Fit Lambda | 88.231 | 1 | 88.231 | 7.2811 | -| Fit Covariance | 59.746 | 1 | 59.746 | 4.9305 | -| Perform Eigendecomposition | 14.831 | 1 | 14.831 | 1.2239 | -| Save Covariance | 5.8912 | 1 | 5.8912 | 0.48617 | -| Save Eigendecomposition | 5.7726 | 1 | 5.7726 | 0.47638 | -| Save Lambda | 1.624 | 1 | 1.624 | 0.13402 | -| Load Covariance | 0.34494 | 1 | 0.34494 | 0.028465 | -| Load Eigendecomposition | 0.33595 | 1 | 0.33595 | 0.027724 | -| Load All Factors | 0.26719 | 1 | 0.26719 | 0.022049 | -| Save Pairwise Score | 0.26006 | 1 | 0.26006 | 0.021461 | +| Compute Pairwise Score | 654.62 | 1 | 654.62 | 83.294 | +| Fit Lambda | 74.662 | 1 | 74.662 | 9.4999 | +| Fit Covariance | 45.784 | 1 | 45.784 | 5.8256 | +| Perform Eigendecomposition | 7.5987 | 1 | 7.5987 | 0.96685 | +| Save Eigendecomposition | 1.4441 | 1 | 1.4441 | 0.18375 | +| Save Covariance | 1.3445 | 1 | 1.3445 | 0.17107 | +| Save Lambda | 0.38279 | 1 | 0.38279 | 0.048705 | +| Load Covariance | 0.058189 | 1 | 0.058189 | 0.0074039 | +| Save Pairwise Score | 0.0094807 | 1 | 0.0094807 | 0.0012063 | +| Load All Factors | 0.0083676 | 1 | 0.0083676 | 0.0010647 | +| Load Eigendecomposition | 0.0053729 | 1 | 0.0053729 | 0.00068364 | ---------------------------------------------------------------------------------------------------------------------------------- ``` The `half_precision_analysis.py` script compares the correlations between `float32` and `bfloat16` scores.

-[figure: Query Batching]
+[figure: Half Precision]

-The average correlation for 481 data points is `0.96`. Finally, we can try using `torch.compile`:
-
-```bash
-python analyze.py --query_batch_size 32 \
-    --train_batch_size 64 \
-    --checkpoint_dir ./checkpoints \
-    --factor_strategy ekfac \
-    --use_half_precision \
-    --use_compile
-```
-
-This reduces computation time to about 16 minutes on an A100 (80GB) GPU:
-
-```
-----------------------------------------------------------------------------------------------------------------------------------
-| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
-----------------------------------------------------------------------------------------------------------------------------------
-| Total | - | 11 | 939.4 | 100 % |
-----------------------------------------------------------------------------------------------------------------------------------
-| Compute Pairwise Score | 735.99 | 1 | 735.99 | 78.347 |
-| Fit Covariance | 103.6 | 1 | 103.6 | 11.029 |
-| Fit Lambda | 69.442 | 1 | 69.442 | 7.3922 |
-| Perform Eigendecomposition | 16.011 | 1 | 16.011 | 1.7044 |
-| Save Covariance | 5.9458 | 1 | 5.9458 | 0.63294 |
-| Save Eigendecomposition | 5.9252 | 1 | 5.9252 | 0.63074 |
-| Save Lambda | 1.5185 | 1 | 1.5185 | 0.16164 |
-| Load Covariance | 0.42047 | 1 | 0.42047 | 0.04476 |
-| Load Eigendecomposition | 0.32199 | 1 | 0.32199 | 0.034276 |
-| Load All Factors | 0.16436 | 1 | 0.16436 | 0.017496 |
-| Save Pairwise Score | 0.055834 | 1 | 0.055834 | 0.0059436 |
-----------------------------------------------------------------------------------------------------------------------------------
-```
+The average correlation for 481 data points is `0.96`.

 ## Counterfactual Experiment

@@ -139,7 +108,8 @@ This reduces computation time to about 16 minutes on an A100 (80GB) GPU:
 ## Evaluating Linear Datamodeling Score

 The `evaluate_lds.py` script computes the [linear datamodeling score (LDS)](https://arxiv.org/abs/2303.14186). It measures the LDS obtained by
-retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks). We obtain `0.43` LDS (we get `0.41` LDS with the half precision).
+retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks). We obtain `0.44` LDS
+(`0.42` LDS with half precision and `0.12` LDS with the `identity` strategy).

 The script also includes functionality to print out top influential sequences for a given query.
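As context for the LDS figures quoted in the README hunk above: the linear datamodeling score measures how well subset sums of predicted influence scores rank-correlate with the losses of models actually retrained on those subsets. Below is a minimal sketch of such an evaluation; the helper name `linear_datamodeling_score` and the arrays `scores`, `masks`, and `retrained_losses` are illustrative assumptions, not the actual contents of `evaluate_lds.py`.

```python
# Hedged sketch of a linear datamodeling score (LDS) evaluation; names and
# shapes are illustrative, not taken from evaluate_lds.py.
import numpy as np
from scipy.stats import spearmanr


def linear_datamodeling_score(
    scores: np.ndarray,  # (num_queries, num_train) influence scores.
    masks: np.ndarray,  # (num_masks, num_train) 0/1 indicators of retained examples.
    retrained_losses: np.ndarray,  # (num_masks, num_queries) losses after retraining.
) -> float:
    # Predicted loss for each subset: sum influence scores over retained examples.
    predicted = masks.astype(scores.dtype) @ scores.T  # (num_masks, num_queries).
    # Average the Spearman rank correlation over query examples.
    correlations = [
        spearmanr(predicted[:, q], retrained_losses[:, q]).correlation
        for q in range(scores.shape[0])
    ]
    return float(np.mean(correlations))
```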
diff --git a/examples/wikitext/analyze.py b/examples/wikitext/analyze.py index 9a01db0..8afbc30 100644 --- a/examples/wikitext/analyze.py +++ b/examples/wikitext/analyze.py @@ -1,7 +1,7 @@ import argparse import logging import os -from typing import Dict, List, Optional +from typing import Dict, List import torch import torch.nn.functional as F @@ -90,23 +90,21 @@ def compute_train_loss( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], ).logits - - shift_logits = logits[..., :-1, :].contiguous() + logits = logits[..., :-1, :].contiguous() + logits = logits.view(-1, logits.size(-1)) if not sample: labels = batch["labels"] - shift_labels = labels[..., 1:].contiguous() - reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - summed_loss = F.cross_entropy(reshaped_shift_logits, shift_labels.view(-1), reduction="sum") + labels = labels[..., 1:].contiguous() + summed_loss = F.cross_entropy(logits, labels.view(-1), reduction="sum") else: - reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) with torch.no_grad(): - probs = torch.nn.functional.softmax(reshaped_shift_logits.detach(), dim=-1) + probs = torch.nn.functional.softmax(logits.detach(), dim=-1) sampled_labels = torch.multinomial( probs, num_samples=1, ).flatten() - summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, reduction="sum") + summed_loss = F.cross_entropy(logits, sampled_labels, reduction="sum") return summed_loss def compute_measurement( @@ -117,7 +115,7 @@ def compute_measurement( # We could also compute the log-likelihood or averaged margin. return self.compute_train_loss(batch, model) - def tracked_modules(self) -> List[str]: + def get_influence_tracked_modules(self) -> List[str]: total_modules = [] for i in range(12): @@ -130,7 +128,7 @@ def tracked_modules(self) -> List[str]: return total_modules - def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]: + def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: return batch["attention_mask"] @@ -183,8 +181,8 @@ def main(): dataset=train_dataset, per_device_batch_size=None, factor_args=factor_args, - overwrite_output_dir=False, initial_per_device_batch_size_attempt=64, + overwrite_output_dir=False, ) # Compute pairwise scores. 
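For reference, the `compute_train_loss` refactor in the hunk above implements the standard shifted next-token cross-entropy: the final position's logits and the first label are dropped, and the loss is summed rather than averaged so that it stays additive over examples. A self-contained sketch of the same pattern (the dummy shapes and the GPT-2 vocabulary size are assumptions made for illustration):

```python
import torch
import torch.nn.functional as F


def summed_next_token_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Summed next-token cross-entropy for (batch, seq_len, vocab_size) logits."""
    # Positions 0..T-2 predict tokens 1..T-1, so drop the final position's logits...
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    # ...and drop the first label.
    shift_labels = labels[..., 1:].contiguous().view(-1)
    # `reduction="sum"` keeps the loss additive over tokens and examples.
    return F.cross_entropy(shift_logits, shift_labels, reduction="sum")


# Illustrative usage with random tensors (50257 is GPT-2's vocabulary size):
logits = torch.randn(2, 8, 50257)
labels = torch.randint(0, 50257, (2, 8))
print(summed_next_token_loss(logits, labels))
```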
@@ -197,15 +195,14 @@ def main(): scores_name += "_compile" rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None if rank is not None: - score_args.query_gradient_rank = rank - score_args.num_query_gradient_accumulations = 10 + score_args.query_gradient_low_rank = rank + score_args.query_gradient_accumulation_steps = 10 scores_name += f"_qlr{rank}" analyzer.compute_pairwise_scores( scores_name=scores_name, score_args=score_args, factors_name=factors_name, query_dataset=eval_dataset, - query_indices=list(range(min([len(eval_dataset), 2000]))), train_dataset=train_dataset, per_device_query_batch_size=args.query_batch_size, per_device_train_batch_size=args.train_batch_size, diff --git a/examples/wikitext/evaluate_lds.py b/examples/wikitext/evaluate_lds.py index 616aa46..654dbb8 100644 --- a/examples/wikitext/evaluate_lds.py +++ b/examples/wikitext/evaluate_lds.py @@ -37,7 +37,7 @@ def main(): # scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half/pairwise_scores.safetensors")[ # "all_modules" # ].to(dtype=torch.float32) - # scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half_compile/pairwise_scores.safetensors")[ + # scores = Analyzer.load_file("influence_results/wikitext/scores_identity/pairwise_scores.safetensors")[ # "all_modules" # ].to(dtype=torch.float32) diff --git a/examples/wikitext/half_precision_analysis.py b/examples/wikitext/half_precision_analysis.py index 8449f50..21cfeb1 100644 --- a/examples/wikitext/half_precision_analysis.py +++ b/examples/wikitext/half_precision_analysis.py @@ -18,9 +18,6 @@ def main(): half_scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half/pairwise_scores.safetensors")[ "all_modules" ].float() - # half_scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half_compile/pairwise_scores.safetensors")[ - # "all_modules" - # ].float() # Only plot first 1000 points to avoid clutter. index = 5 diff --git a/examples/wikitext/run_counterfactual.py b/examples/wikitext/run_counterfactual.py index c479ca6..e91e189 100644 --- a/examples/wikitext/run_counterfactual.py +++ b/examples/wikitext/run_counterfactual.py @@ -34,10 +34,9 @@ def get_topk_keep_indices(current_score: torch.Tensor, topk: int = 1) -> List[in valid_dataset = get_wikitext_dataset(split="valid", indices=list(range(50))) - def train_and_evaluate(indices): - train_dataset = get_wikitext_dataset(split="train", indices=indices) + def train_and_evaluate(indices) -> float: model = train( - dataset=train_dataset, + dataset=get_wikitext_dataset(split="train", indices=indices), batch_size=8, num_train_epochs=3, learning_rate=3e-05, diff --git a/kronfluence/analyzer.py b/kronfluence/analyzer.py index f3582f7..615c1c0 100644 --- a/kronfluence/analyzer.py +++ b/kronfluence/analyzer.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Dict, Optional, Union @@ -20,33 +21,32 @@ def prepare_model( model: nn.Module, task: Task, ) -> nn.Module: - """Prepares the model before passing it to `Analyzer`. This function sets `param.requires_grad = False` - for all modules and installs `TrackedModule` to supported modules. This `TrackedModule` keeps track of relevant - statistics needed to compute influence scores. + """Prepares the model for analysis by setting all parameters and buffers to non-trainable + and installing `TrackedModule` wrappers on supported modules. Args: model (nn.Module): - The PyTorch model to be analyzed. + The PyTorch model to be prepared for analysis. 
task (Task): - The specific task associated with the model. + The specific task associated with the model, used for `TrackedModule` installation. Returns: nn.Module: - The PyTorch model with `param.requires_grad = False` on all modules and with `TrackedModule` installed. + The prepared model with non-trainable parameters and `TrackedModule` wrappers. """ model.eval() for params in model.parameters(): params.requires_grad = False for buffers in model.buffers(): buffers.requires_grad = False - # Install `TrackedModule` to the model. + + # Install `TrackedModule` wrappers on supported modules. model = wrap_tracked_modules(model=model, task=task) return model class Analyzer(FactorComputer, ScoreComputer): - """Handles the computation of all factors (e.g., covariance and Lambda matrices for EKFAC) - and influence scores for a given PyTorch model.""" + """Handles the computation of factors (e.g., covariance matrices) and scores for a given PyTorch model.""" def __init__( self, @@ -61,33 +61,33 @@ def __init__( output_dir: str = "./influence_results", disable_model_save: bool = True, ) -> None: - """Initializes an instance of the Analyzer class. + """Initializes an instance of the `Analyzer` class. Args: analysis_name (str): - The unique identifier for the analysis, used to organize and retrieve the results. + Unique identifier for the analysis, used for organizing results. model (nn.Module): The PyTorch model to be analyzed. task (Task): The specific task associated with the model. cpu (bool, optional): - Specifies whether the analysis should be explicitly performed using the CPU. - Defaults to False, utilizing GPU resources if available. + If `True`, forces analysis to be performed on CPU. Defaults to `False`. log_level (int, optional): - The logging level to use (e.g., logging.DEBUG, logging.INFO). Defaults to the root logging level. + Logging level (e.g., logging.DEBUG, logging.INFO). Defaults to root logging level. log_main_process_only (bool, optional): - If True, restricts logging to the main process. Defaults to True. + If `True`, restricts logging to the main process. Defaults to `True`. profile (bool, optional): - Enables the generation of performance profiling logs. This can be useful for - identifying bottlenecks or performance issues. Defaults to False. + If `True`, enables performance profiling logs. Defaults to `False`. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + If `True`, disables TQDM progress bars. Defaults to `False`. output_dir (str): - The file path to the directory, where analysis results will be stored. If the directory - does not exist, it will be created. Defaults to './influence_results'. + Directory path for storing analysis results. Defaults to './influence_results'. disable_model_save (bool, optional): - If set to True, prevents the saving of the model's state_dict. When the provided model is different - from the previously saved model, it will raise an Exception. Defaults to True. + If `True`, prevents saving the model's `state_dict`. Defaults to `True`. + + Raises: + ValueError: + If the provided model differs from a previously saved model when `disable_model_save=False`. 
""" super().__init__( name=analysis_name, @@ -100,20 +100,20 @@ def __init__( disable_tqdm=disable_tqdm, output_dir=output_dir, ) - self.logger.info(f"Initializing Computer with parameters: {locals()}") - self.logger.debug(f"Process state configuration:\n{repr(self.state)}") + self.logger.info(f"Initializing `Analyzer` with parameters: {locals()}") + self.logger.info(f"Process state configuration:\n{repr(self.state)}") - # Saves model parameters. + # Save model parameters if necessary. if self.state.is_main_process and not disable_model_save: self._save_model() self.state.wait_for_everyone() def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None: - """Sets the default DataLoader parameters to use for all DataLoaders. + """Sets the default DataLoader arguments. Args: dataloader_kwargs (DataLoaderKwargs): - The object containing parameters for DataLoader. + The object containing arguments for PyTorch DataLoader. """ self._dataloader_params = dataloader_kwargs @@ -121,17 +121,16 @@ def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None: def _save_model(self) -> None: """Saves the model to the output directory.""" model_save_path = self.output_dir / "model.safetensors" - extracted_model = extract_model_from_parallel(model=self.model, keep_fp32_wrapper=True) + extracted_model = extract_model_from_parallel(model=copy.deepcopy(self.model), keep_fp32_wrapper=True) if model_save_path.exists(): self.logger.info(f"Found existing saved model at `{model_save_path}`.") - # Load the existing model's state_dict for comparison. + # Load existing model's `state_dict` for comparison. loaded_state_dict = load_file(model_save_path) if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()): error_msg = ( "Detected a difference between the current model and the one saved at " - f"`{model_save_path}`. Consider using a different `analysis_name` to " - f"avoid conflicts." + f"`{model_save_path}`. Consider using a different `analysis_name` to avoid conflicts." ) self.logger.error(error_msg) raise ValueError(error_msg) @@ -152,27 +151,24 @@ def fit_all_factors( factor_args: Optional[FactorArguments] = None, overwrite_output_dir: bool = False, ) -> None: - """Computes all necessary factors for the given factor strategy. As an example, EK-FAC - requires (1) computing covariance matrices, (2) performing Eigendecomposition, and - (3) computing Lambda (corrected eigenvalues) matrices. + """Computes all necessary factors for the given strategy. Args: factors_name (str): - The unique identifier for the factor, used to organize and retrieve the results. + Unique identifier for the factor, used for organizing results. dataset (data.Dataset): - The dataset that will be used to fit all the factors. + Dataset used to fit all influence factors. per_device_batch_size (int, optional): - The per-device batch size used to fit the factors. If not specified, executable - per-device batch size is automatically determined. + Per-device batch size for factor fitting. If not specified, executable per-device batch size + is automatically determined. initial_per_device_batch_size_attempt (int): - The initial attempted per-device batch size when the batch size is not provided. + Initial batch size attempt when `per_device_batch_size` is not explicitly provided. Defaults to `4096`. dataloader_kwargs (DataLoaderKwargs, optional): - Controls additional arguments for PyTorch's DataLoader. + Additional arguments for PyTorch's DataLoader. 
factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, - the default values of `FactorArguments` will be used. + Arguments for factor computation. Defaults to `FactorArguments` default values. overwrite_output_dir (bool, optional): - If True, the existing factors with the same name will be overwritten. + If `True`, overwrites existing factors with the same `factors_name`. Defaults to `False`. """ self.fit_covariance_matrices( factors_name=factors_name, @@ -200,36 +196,41 @@ def fit_all_factors( @staticmethod def load_file(path: Union[str, Path]) -> Dict[str, torch.Tensor]: - """Loads the `.safetensors` file at the given path from disk. - - See https://github.com/huggingface/safetensors. + """Loads a `safetensors` file from the given path. Args: path (Path): - The path to the `.safetensors` file. + The path to the `safetensors` file. Returns: Dict[str, torch.Tensor]: - The contents of the file, which is the dictionary mapping string to tensors. + Dictionary mapping strings to tensors from the loaded file. + + Raises: + FileNotFoundError: + If the specified file does not exist. + + Note: + For more information on `safetensors`, see https://github.com/huggingface/safetensors. """ if isinstance(path, str): path = Path(path).resolve() if not path.exists(): - raise FileNotFoundError(f"File does not exists at `{path}`.") + raise FileNotFoundError(f"File not found: {path}.") return load_file(path) @staticmethod def get_module_summary(model: nn.Module) -> str: - """Returns the formatted summary of the modules in model. Useful identifying which modules to - compute influence scores. + """Generates a formatted summary of the model's modules, excluding those without parameters. This summary is + useful for identifying which modules to compute influence scores for. Args: model (nn.Module): - The PyTorch model to be investigated. + The PyTorch model to be summarized. Returns: str: - The formatted string summary of the model. + A formatted string containing the model summary, including module names and representations. """ format_str = "==Model Summary==" for module_name, module in model.named_modules(): diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py index db3978d..242688d 100644 --- a/kronfluence/arguments.py +++ b/kronfluence/arguments.py @@ -10,7 +10,12 @@ class Arguments: """Base class for specifying arguments for computing factors and influence scores.""" def to_dict(self) -> Dict[str, Any]: - """Converts the arguments to a dictionary.""" + """Converts the arguments to a dictionary. + + Returns: + Dict[str, Any]: + A dictionary representation of the arguments, with `torch.dtype` values converted to strings. + """ config = copy.deepcopy(self.__dict__) for key, value in config.items(): if isinstance(value, torch.dtype): @@ -18,7 +23,12 @@ def to_dict(self) -> Dict[str, Any]: return config def to_str_dict(self) -> Dict[str, str]: - """Converts the arguments to a dictionary, where all values are converted to strings.""" + """Converts the arguments to a dictionary with all values as strings. + + Returns: + Dict[str, str]: + A dictionary representation of the arguments, with all values converted to strings. + """ config = copy.deepcopy(self.__dict__) for key, value in config.items(): config[key] = str(value) @@ -29,196 +39,228 @@ def to_str_dict(self) -> Dict[str, str]: class FactorArguments(Arguments): """Arguments for computing influence factors.""" - # General configuration. 
# + # General configuration # strategy: str = field( default="ekfac", - metadata={"help": "Strategy for computing preconditioning factors."}, + metadata={"help": "Specifies the algorithm for computing influence factors. Default is 'ekfac'."}, ) use_empirical_fisher: bool = field( default=False, metadata={ - "help": "Whether to use empirical Fisher (using labels from batch) instead of " - "true Fisher (using sampled labels)." - }, - ) - distributed_sync_steps: int = field( - default=1_000, - metadata={ - "help": "Specifies the total iteration step to synchronize the process when using distributed setting." + "help": "If `True`, approximates empirical Fisher (using true labels) instead of " + "true Fisher (using sampled labels from model's outputs)." }, ) amp_dtype: Optional[torch.dtype] = field( default=None, - metadata={"help": "Dtype for automatic mixed precision (AMP). Disables AMP if None."}, + metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."}, ) - shared_parameters_exist: bool = field( + has_shared_parameters: bool = field( default=False, - metadata={"help": "Specifies whether the shared parameters exist in the forward pass."}, + metadata={"help": "Indicates whether shared parameters are present in the model's forward pass."}, ) # Configuration for fitting covariance matrices. # covariance_max_examples: Optional[int] = field( default=100_000, - metadata={ - "help": "Maximum number of examples for fitting covariance matrices. " - "Uses all data examples for the given dataset if None." - }, + metadata={"help": "Maximum number of examples for fitting covariance matrices. Uses entire dataset if `None`."}, ) - covariance_data_partition_size: int = field( + covariance_data_partitions: int = field( default=1, - metadata={ - "help": "Number of data partitions for computing covariance matrices. " - "For example, when `covariance_data_partition_size = 2`, the dataset is split " - "into 2 chunks and covariance matrices are separately computed for each chunk." - }, + metadata={"help": "Number of partitions to divide the dataset into for covariance matrix computation."}, ) - covariance_module_partition_size: int = field( + covariance_module_partitions: int = field( default=1, metadata={ - "help": "Number of module partitions for computing covariance matrices. " - "For example, when `covariance_module_partition_size = 2`, the modules (layers) are split " - "into 2 chunks and covariance matrices are separately computed for each chunk." + "help": "Number of partitions to divide the model's modules (layers) into for " + "covariance matrix computation." }, ) activation_covariance_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing activation covariance matrices."}, + metadata={"help": "Data type for activation covariance computation."}, ) gradient_covariance_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing pseudo-gradient covariance matrices."}, + metadata={"help": "Data type for pseudo-gradient covariance computation."}, ) - # Configuration for performing eigendecomposition. # + # Configuration for performing eigendecomposition # eigendecomposition_dtype: torch.dtype = field( default=torch.float64, - metadata={"help": "Dtype for performing Eigendecomposition. Recommended to use `torch.float64`."}, + metadata={ + "help": "Data type for eigendecomposition. Double precision (`torch.float64`) is recommended " + "for numerical stability." 
+ }, ) - # Configuration for fitting Lambda matrices. # + # Configuration for fitting Lambda matrices # lambda_max_examples: Optional[int] = field( default=100_000, - metadata={ - "help": "Maximum number of examples for fitting Lambda matrices. " - "Uses all data examples for the given dataset if None." - }, + metadata={"help": "Maximum number of examples for fitting Lambda matrices. Uses entire dataset if `None`."}, ) - lambda_data_partition_size: int = field( + lambda_data_partitions: int = field( default=1, - metadata={ - "help": "Number of data partitions for computing Lambda matrices. " - "For example, when `lambda_data_partition_size = 2`, the dataset is split " - "into 2 chunks and Lambda matrices are separately computed for each chunk." - }, + metadata={"help": "Number of partitions to divide the dataset into for Lambda matrix computation."}, ) - lambda_module_partition_size: int = field( + lambda_module_partitions: int = field( default=1, metadata={ - "help": "Number of module partitions for computing Lambda matrices. " - "For example, when `lambda_module_partition_size = 2`, the modules (layers) are split " - "into 2 chunks and Lambda matrices are separately computed for each chunk." + "help": "Number of partitions to divide the model's modules (layers) into for Lambda matrix computation." }, ) - lambda_iterative_aggregate: bool = field( + use_iterative_lambda_aggregation: bool = field( default=False, metadata={ - "help": "Whether to aggregate squared sum of projected per-sample-gradient with for-loop iterations." + "help": "If `True`, aggregates the squared sum of projected per-sample gradients " + "iteratively (with for-loop) to reduce GPU memory usage." }, ) - cached_activation_cpu_offload: bool = field( + offload_activations_to_cpu: bool = field( default=False, - metadata={"help": "Whether to offload cached activation to CPU for computing the per-sample-gradient."}, + metadata={"help": "If `True`, offloads cached activations to CPU memory when computing per-sample gradients."}, ) per_sample_gradient_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing per-sample-gradients."}, + metadata={"help": "Data type for per-sample gradient computation."}, ) lambda_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing Lambda (corrected eigenvalues) matrices."}, + metadata={"help": "Data type for Lambda matrix computation."}, ) + def __post_init__(self) -> None: + if self.covariance_max_examples is not None and self.covariance_max_examples <= 0: + raise ValueError("`covariance_max_examples` must be `None` or positive.") + + if self.lambda_max_examples is not None and self.lambda_max_examples <= 0: + raise ValueError("`lambda_max_examples` must be `None` or positive.") + + if any( + partition <= 0 + for partition in [ + self.covariance_data_partitions, + self.covariance_module_partitions, + self.lambda_data_partitions, + self.lambda_module_partitions, + ] + ): + raise ValueError("All data and module partitions must be positive.") + @dataclass class ScoreArguments(Arguments): """Arguments for computing influence scores.""" - # General configuration. # - damping: Optional[float] = field( + # General configuration # + damping_factor: Optional[float] = field( default=1e-08, metadata={ - "help": "A damping factor for the damped inverse Hessian-vector product (iHVP). " - "Uses a heuristic based on mean eigenvalues (0.1 x mean eigenvalues) if set to None." 
- }, - ) - cached_activation_cpu_offload: bool = field( - default=False, - metadata={"help": "Whether to offload cached activation to CPU for computing the per-sample-gradient."}, - ) - distributed_sync_steps: int = field( - default=1_000, - metadata={ - "help": "Specifies the total iteration step to synchronize the process when using distributed setting." + "help": "Damping factor for inverse Hessian-vector product (iHVP). " + "If `None`, uses a heuristic of 0.1 times the mean eigenvalue." }, ) amp_dtype: Optional[torch.dtype] = field( default=None, - metadata={"help": "Dtype for automatic mixed precision (AMP). Disables AMP if None."}, + metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."}, + ) + offload_activations_to_cpu: bool = field( + default=False, + metadata={"help": "If `True`, offloads cached activations to CPU memory when computing per-sample gradients."}, ) - # Partition configuration. # - data_partition_size: int = field( + # Partition configuration # + data_partitions: int = field( default=1, - metadata={ - "help": "Number of data partitions for computing influence scores. For example, when " - "`data_partition_size = 2`, the dataset is split into 2 chunks and scores are separately " - "computed for each chunk." - }, + metadata={"help": "Number of partitions to divide the dataset for influence score computation."}, ) - module_partition_size: int = field( + module_partitions: int = field( default=1, metadata={ - "help": "Number of module partitions for computing influence scores. For example, when " - "`module_partition_size = 2`, the module (layers) are split into 2 modules and scores are separately " - "computed for each chunk." + "help": "Number of partitions to divide the model's modules (layers) into for influence score computation." }, ) - # Score configuration. # - per_module_score: bool = field( + # General score configuration # + compute_per_module_scores: bool = field( + default=False, + metadata={"help": "If `True`, computes separate scores for each module instead of summing across all."}, + ) + compute_per_token_scores: bool = field( default=False, metadata={ - "help": "Whether to obtain per-module scores instead of the summed scores across all modules. " - "This is useful when performing layer-wise influence analysis." + "help": "If `True`, computes separate scores for each token instead of summing across all. " + "Only applicable to transformer-based models." }, ) - num_query_gradient_accumulations: int = field( + + # Pairwise influence score configuration # + query_gradient_accumulation_steps: int = field( default=1, - metadata={"help": "Number of query batches to accumulate over before iterating over training examples."}, + metadata={"help": "Number of query batches to accumulate before processing training examples."}, ) - query_gradient_rank: Optional[int] = field( + query_gradient_low_rank: Optional[int] = field( default=None, - metadata={"help": "Rank for the query gradient. Does not apply low-rank approximation if None."}, + metadata={ + "help": "Rank for the low-rank approximation of the query gradient (query batching). " + "If `None`, no low-rank approximation is applied." + }, + ) + use_full_svd: bool = field( + default=False, + metadata={ + "help": "If `True`, uses `torch.linalg.svd` instead of `torch.svd_lowrank` for query batching. " + "Full SVD is more accurate but slower and more memory-intensive." 
+ }, ) + aggregate_query_gradients: bool = field( + default=False, + metadata={ + "help": "If `True`, uses the summed query gradient instead of per-sample query gradients " + "for pairwise influence computation." + }, + ) + aggregate_train_gradients: bool = field( + default=False, + metadata={ + "help": "If `True`, uses the summed training gradient instead of per-sample training gradients " + "for pairwise influence computation." + }, + ) + + # Self-influence score configuration # use_measurement_for_self_influence: bool = field( default=False, - metadata={"help": "Whether to use the measurement (instead of the loss) for computing self-influence scores."}, + metadata={"help": "If `True`, uses the measurement (instead of the loss) for computing self-influence scores."}, ) - # Dtype configuration. # + # Precision configuration # query_gradient_svd_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for performing singular value decomposition (SVD) on the query gradient."}, - ) - score_dtype: torch.dtype = field( - default=torch.float32, - metadata={"help": "Dtype for computing and storing influence scores."}, + metadata={"help": "Data type for singular value decomposition (SVD) of query gradient."}, ) per_sample_gradient_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing per-sample-gradients."}, + metadata={"help": "Data type for query per-sample gradient computation."}, ) precondition_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing the preconditioned gradient. Recommended to use `torch.float32`."}, + metadata={"help": "Data type for preconditioned gradient computation."}, ) + score_dtype: torch.dtype = field( + default=torch.float32, + metadata={"help": "Data type for influence score computation."}, + ) + + def __post_init__(self) -> None: + if self.damping_factor is not None and self.damping_factor < 0: + raise ValueError("`damping_factor` must be `None` or positive.") + + if any(partition <= 0 for partition in [self.data_partitions, self.module_partitions]): + raise ValueError("Both data and module partitions must be positive.") + + if self.query_gradient_accumulation_steps <= 0: + raise ValueError("`query_gradient_accumulation_steps` must be positive.") + + if self.query_gradient_low_rank is not None and self.query_gradient_low_rank <= 0: + raise ValueError("`query_gradient_low_rank` must be `None` or positive.") diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py index 74c8bb3..876e6f2 100644 --- a/kronfluence/computer/computer.py +++ b/kronfluence/computer/computer.py @@ -23,11 +23,23 @@ load_eigendecomposition, load_lambda_matrices, ) -from kronfluence.module.utils import get_tracked_module_names, make_modules_partition +from kronfluence.module.tracked_module import ModuleMode +from kronfluence.module.utils import ( + get_tracked_module_names, + make_modules_partition, + set_mode, +) from kronfluence.score.pairwise import load_pairwise_scores, pairwise_scores_exist from kronfluence.score.self import load_self_scores, self_scores_exist from kronfluence.task import Task -from kronfluence.utils.constants import FACTOR_TYPE, SCORE_TYPE +from kronfluence.utils.constants import ( + FACTOR_ARGUMENTS_NAME, + FACTOR_SAVE_PREFIX, + FACTOR_TYPE, + SCORE_ARGUMENTS_NAME, + SCORE_SAVE_PREFIX, + SCORE_TYPE, +) from kronfluence.utils.dataset import ( DataLoaderKwargs, DistributedEvalSampler, @@ -39,16 +51,8 @@ TrackedModuleNotFoundError, ) from 
kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger -from kronfluence.utils.model import apply_ddp -from kronfluence.utils.save import ( - FACTOR_ARGUMENTS_NAME, - FACTOR_SAVE_PREFIX, - SCORE_ARGUMENTS_NAME, - SCORE_SAVE_PREFIX, - load_json, - save_json, -) -from kronfluence.utils.state import State +from kronfluence.utils.save import load_json, save_json +from kronfluence.utils.state import State, release_memory class Computer(ABC): @@ -66,10 +70,10 @@ def __init__( profile: bool = False, disable_tqdm: bool = False, ) -> None: - """Initializes an instance of the Computer class.""" + """Initializes an instance of the `Computer` class. See `Analyzer` for more information.""" self.state = State(cpu=cpu) - # Creates and configures logger. + # Create and configure logger. disable_log = log_main_process_only and self.state.process_index != 0 self.logger = get_logger(name=__name__, log_level=log_level, disable_log=disable_log) @@ -80,47 +84,44 @@ def __init__( if len(tracked_module_names) == 0: error_msg = ( f"No tracked modules found in the provided model: {self.model}. " - f"Please make sure to run `prepare_model` before passing it in to the " - f"Analyzer." + f"Please ensure you've run `prepare_model` before passing it to the Analyzer." ) self.logger.error(error_msg) raise TrackedModuleNotFoundError(error_msg) self.logger.info(f"Tracking modules with names: {tracked_module_names}.") if self.state.use_distributed and not isinstance(model, (DDP, FSDP)): - warning_msg = ( - "Creating a DDP module. If specific configuration needs to be used " - "for DDP, please pass in the model after the manual DDP wrapping." + self.logger.warning( + "Creating a DDP module. For custom DDP configuration, " + "please manually wrap the model with DDP before passing it in." ) - self.logger.warning(warning_msg) self.model.to(self.state.device) - self.model = apply_ddp( - model=self.model, - local_rank=self.state.local_process_index, - rank=self.state.process_index, - world_size=self.state.num_processes, + self.model = DDP( + self.model, + device_ids=[self.state.local_process_index], + output_device=self.state.local_process_index, ) if cpu and isinstance(model, (DataParallel, DDP, FSDP)): - error_msg = "To enforce CPU, the model should not be wrapped with DP, DDP, or FSDP." + error_msg = ( + "CPU enforcement is incompatible with DP, DDP, or FSDP wrapped models. " + "Please provide an unwrapped model when using `cpu=True`." + ) self.logger.error(error_msg) raise ValueError(error_msg) if not self.state.use_distributed: self.model.to(self.state.device) - # Creates and configures output directory. + # Create and configure output directory. self.output_dir = Path(output_dir).joinpath(name).resolve() os.makedirs(name=self.output_dir, exist_ok=True) - # Creates and configures profiler. + # Create and configure profiler. self.profiler = Profiler(state=self.state) if profile else PassThroughProfiler(state=self.state) - # Creates directory to save profiler output. - self.profiler_dir = (self.output_dir / "profiler_output").resolve() - os.makedirs(name=self.profiler_dir, exist_ok=True) self.disable_tqdm = disable_tqdm - # Sets PyTorch DataLoader arguments. + # Set PyTorch DataLoader arguments. 
self._dataloader_params = DataLoaderKwargs() def factors_output_dir(self, factors_name: str) -> Path: @@ -145,11 +146,11 @@ def _save_arguments( loaded_arguments = load_json(arguments_save_path) if loaded_arguments != arguments.to_dict(): error_msg = ( - "Attempting to use the arguments that differs from the one already saved. " - "Please set `overwrite_output_dir=True` to overwrite existing experiment." + f"New arguments differ from saved arguments at `{arguments_save_path}`. " + "Set `overwrite_output_dir=True` to overwrite existing experiment.\n" + f"New arguments: {arguments.to_dict()}\n" + f"Saved arguments: {loaded_arguments}" ) - error_msg += f"\nNew arguments: {arguments.to_dict()}." - error_msg += f"\nSaved arguments: {loaded_arguments}." self.logger.error(error_msg) raise ValueError(error_msg) else: @@ -198,13 +199,13 @@ def _get_dataloader( allow_duplicates: bool = False, stack: bool = False, ) -> data.DataLoader: - """Returns the DataLoader for the given dataset, per_device_batch_size, and additional parameters.""" + """Returns the DataLoader with the provided configuration.""" if indices is not None: dataset = data.Subset(dataset=dataset, indices=indices) if self.state.use_distributed and not allow_duplicates: if stack: - error_msg = "DistributedEvalSampler is not currently supported with `stack=True`." + error_msg = "DistributedEvalSampler is incompatible with `stack=True`." self.logger.error(error_msg) raise ValueError(error_msg) sampler = DistributedEvalSampler( @@ -249,35 +250,31 @@ def _configure_dataloader(self, dataloader_kwargs: DataLoaderKwargs) -> Dict[str def _get_data_partition( self, total_data_examples: int, - data_partition_size: int, + data_partitions: int, target_data_partitions: Optional[Union[int, List[int]]], ) -> Tuple[List[Tuple[int, int]], List[int]]: """Partitions the dataset into several chunks.""" - if total_data_examples < data_partition_size: + if total_data_examples < data_partitions: error_msg = ( - f"Data partition size ({data_partition_size}) cannot be greater than the " - f"total data points ({total_data_examples}). Please reduce the data partition " - f"size in the argument." + f"Data partition size ({data_partitions}) exceeds total data points ({total_data_examples}). " + "Please reduce the data partition size." ) self.logger.error(error_msg) raise ValueError(error_msg) indices_partitions = make_indices_partition( - total_data_examples=total_data_examples, partition_size=data_partition_size + total_data_examples=total_data_examples, partition_size=data_partitions ) if target_data_partitions is None: - target_data_partitions = list(range(data_partition_size)) + target_data_partitions = list(range(data_partitions)) if isinstance(target_data_partitions, int): target_data_partitions = [target_data_partitions] for data_partition in target_data_partitions: - if data_partition < 0 or data_partition > data_partition_size: - error_msg = ( - f"Invalid data partition {data_partition} encountered. " - f"The module partition needs to be in between [0, {data_partition_size})." - ) + if data_partition < 0 or data_partition > data_partitions: + error_msg = f"Invalid data partition {data_partition}. Must be in range [0, {data_partitions})." 
self.logger.error(error_msg) raise ValueError(error_msg) @@ -285,49 +282,53 @@ def _get_data_partition( def _get_module_partition( self, - module_partition_size: int, + module_partitions: int, target_module_partitions: Optional[Union[int, List[int]]], ) -> Tuple[List[List[str]], List[int]]: """Partitions the modules into several chunks.""" tracked_module_names = get_tracked_module_names(self.model) - if len(tracked_module_names) < module_partition_size: + if len(tracked_module_names) < module_partitions: error_msg = ( - f"Module partition size ({module_partition_size}) cannot be greater than the " - f"total tracked modules ({len(tracked_module_names)}). Please reduce the module partition " - f"size in the argument." + f"Module partition size ({module_partitions}) exceeds total tracked modules " + f"({len(tracked_module_names)}). Please reduce the module partition size." ) self.logger.error(error_msg) raise ValueError(error_msg) modules_partition_list = make_modules_partition( total_module_names=tracked_module_names, - partition_size=module_partition_size, + partition_size=module_partitions, ) if target_module_partitions is None: - target_module_partitions = list(range(module_partition_size)) + target_module_partitions = list(range(module_partitions)) if isinstance(target_module_partitions, int): target_module_partitions = [target_module_partitions] for module_partition in target_module_partitions: - if module_partition < 0 or module_partition > module_partition_size: - error_msg = ( - f"Invalid module partition {module_partition} encountered. " - f"The module partition needs to be in between [0, {module_partition_size})." - ) + if module_partition < 0 or module_partition > module_partitions: + error_msg = f"Invalid module partition {module_partition}. Must be in range [0, {module_partitions})." self.logger.error(error_msg) raise ValueError(error_msg) return modules_partition_list, target_module_partitions - def _log_profile_summary(self) -> None: + def _reset_memory(self) -> None: + """Clears all cached memory.""" + self.model.zero_grad(set_to_none=True) + set_mode(model=self.model, mode=ModuleMode.DEFAULT, release_memory=True) + release_memory() + + def _log_profile_summary(self, name: str) -> None: """Saves the summary of the profiling results.""" profile_summary = self.profiler.summary() time_str = time.strftime("%Y%m%d_%H%M%S") - profile_save_path = (self.profiler_dir / f"summary_rank_{self.state.process_index}_{time_str}.txt").resolve() + profiler_dir = (self.output_dir / "profiler_output").resolve() + profile_save_path = (profiler_dir / f"{name}_summary_rank_{self.state.process_index}_{time_str}.txt").resolve() if profile_summary != "": + os.makedirs(name=profiler_dir, exist_ok=True) self.logger.info(profile_summary) with open(profile_save_path, "a", encoding="utf-8") as f: f.write(profile_summary) @@ -394,7 +395,7 @@ def load_all_factors(self, factors_name: str) -> FACTOR_TYPE: if factor_args is None: error_msg = f"Factors with name `{factors_name}` was not found at `{factors_output_dir}`." 
self.logger.error(error_msg) - raise ValueError(error_msg) + raise FileNotFoundError(error_msg) loaded_factors: FACTOR_TYPE = {} factor_config = FactorConfig.CONFIGS[factor_args.strategy] diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py index 03f8a17..5b7b1d4 100644 --- a/kronfluence/computer/factor_computer.py +++ b/kronfluence/computer/factor_computer.py @@ -25,14 +25,10 @@ save_eigendecomposition, save_lambda_matrices, ) -from kronfluence.module.tracked_module import ModuleMode -from kronfluence.module.utils import set_mode -from kronfluence.utils.constants import FACTOR_TYPE +from kronfluence.utils.constants import FACTOR_ARGUMENTS_NAME, FACTOR_TYPE from kronfluence.utils.dataset import DataLoaderKwargs, find_executable_batch_size from kronfluence.utils.exceptions import FactorsNotFoundError from kronfluence.utils.logger import get_time -from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME -from kronfluence.utils.state import release_memory class FactorComputer(Computer): @@ -41,7 +37,7 @@ class FactorComputer(Computer): def _configure_and_save_factor_args( self, factor_args: Optional[FactorArguments], factors_output_dir: Path, overwrite_output_dir: bool ) -> FactorArguments: - """Configures the provided factor arguments and saves it in disk.""" + """Configures and saves factor arguments to disk.""" if factor_args is None: factor_args = FactorArguments() self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") @@ -62,31 +58,31 @@ def _configure_and_save_factor_args( def _aggregate_factors( self, factors_name: str, - data_partition_size: int, - module_partition_size: int, - exists_fnc: Callable, + data_partitions: int, + module_partitions: int, + exist_fnc: Callable, load_fnc: Callable, save_fnc: Callable, ) -> Optional[FACTOR_TYPE]: """Aggregates factors computed for all data and module partitions.""" factors_output_dir = self.factors_output_dir(factors_name=factors_name) if not factors_output_dir.exists(): - error_msg = f"Factors directory `{factors_output_dir}` is not found when trying to aggregate factors." + error_msg = f"Factors directory `{factors_output_dir}` not found when trying to aggregate factors." 
self.logger.error(error_msg) raise FileNotFoundError(error_msg) - all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] + all_required_partitions = [(i, j) for i in range(data_partitions) for j in range(module_partitions)] all_partition_exists = all( - exists_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions + exist_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions ) if not all_partition_exists: - self.logger.warning("Factors are not aggregated as factors for some partitions are not yet computed.") + self.logger.info("Factors are not aggregated as factors for some partitions are not yet computed.") return start_time = time.time() aggregated_factors: FACTOR_TYPE = {} - for data_partition in range(data_partition_size): - for module_partition in range(module_partition_size): + for data_partition in range(data_partitions): + for module_partition in range(module_partitions): loaded_factors = load_fnc( output_dir=factors_output_dir, partition=(data_partition, module_partition), @@ -97,9 +93,11 @@ def _aggregate_factors( for module_name in factors: if module_name not in aggregated_factors[factor_name]: - aggregated_factors[factor_name][module_name] = factors[module_name] - else: - aggregated_factors[factor_name][module_name].add_(factors[module_name]) + aggregated_factors[factor_name][module_name] = torch.zeros_like( + factors[module_name], + requires_grad=False, + ) + aggregated_factors[factor_name][module_name].add_(factors[module_name]) del loaded_factors save_fnc( output_dir=factors_output_dir, @@ -121,7 +119,7 @@ def _find_executable_factors_batch_size( """Automatically finds executable batch size for performing `func`.""" if self.state.use_distributed: error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " + "Automatic batch size search is not supported for multi-GPU setting. " "Please manually configure the batch size by passing in `per_device_batch_size`." ) self.logger.error(error_msg) @@ -140,9 +138,7 @@ def _find_executable_factors_batch_size( def executable_batch_size_func(batch_size: int) -> None: self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") # Releases all memory that could be caused by the previous OOM. - self.model.zero_grad(set_to_none=True) - set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) - release_memory() + self._reset_memory() total_batch_size = batch_size * self.state.num_processes loader = self._get_dataloader( dataset=dataset, @@ -178,7 +174,7 @@ def fit_covariance_matrices( factors_name (str): The unique identifier for the factor, used to organize and retrieve the results. dataset (data.Dataset): - The dataset that will be used to fit covariance matrices. + The dataset that will be used for fitting covariance matrices. per_device_batch_size (int, optional): The per-device batch size used to fit the factors. If not specified, executable batch size is automatically determined. @@ -187,8 +183,7 @@ def fit_covariance_matrices( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, the default values of - `FactorArguments` will be used. + Arguments for factor computation. target_data_partitions(Sequence[int], int, optional): The list of data partition to fit covariance matrices. 
By default, covariance matrices will be computed for all partitions.
@@ -196,7 +191,7 @@
             The list of module partition to fit covariance matrices.
             By default, covariance matrices will be computed for all partitions.
         overwrite_output_dir (bool, optional):
-            If True, the existing factors with the same `factors_name` will be overwritten.
+            Whether to overwrite existing output.
     """
     self.logger.debug(f"Fitting covariance matrices with parameters: {locals()}")
@@ -231,9 +226,7 @@
     total_data_examples = min([factor_args.covariance_max_examples, len(dataset)])
     self.logger.info(f"Total data examples to fit covariance matrices: {total_data_examples}.")

-    no_partition = (
-        factor_args.covariance_data_partition_size == 1 and factor_args.covariance_module_partition_size == 1
-    )
+    no_partition = factor_args.covariance_data_partitions == 1 and factor_args.covariance_module_partitions == 1
     partition_provided = target_data_partitions is not None or target_module_partitions is not None
     if no_partition and partition_provided:
         error_msg = (
@@ -245,17 +238,17 @@
     data_partition_indices, target_data_partitions = self._get_data_partition(
         total_data_examples=total_data_examples,
-        data_partition_size=factor_args.covariance_data_partition_size,
+        data_partitions=factor_args.covariance_data_partitions,
         target_data_partitions=target_data_partitions,
     )
-    max_partition_examples = total_data_examples // factor_args.covariance_data_partition_size
+    max_partition_examples = total_data_examples // factor_args.covariance_data_partitions
     module_partition_names, target_module_partitions = self._get_module_partition(
-        module_partition_size=factor_args.covariance_module_partition_size,
+        module_partitions=factor_args.covariance_module_partitions,
         target_module_partitions=target_module_partitions,
     )

     if max_partition_examples < self.state.num_processes:
-        error_msg = "The number of processes are more than the data examples. Try reducing the number of processes."
+        error_msg = "The number of processes is larger than the total number of data examples. Try reducing the number of processes."
         self.logger.error(error_msg)
         raise ValueError(error_msg)
@@ -304,7 +297,7 @@
             total_data_examples=max_partition_examples,
         )

-        release_memory()
+        self._reset_memory()
         start_time = get_time(state=self.state)
         with self.profiler.profile("Fit Covariance"):
             loader = self._get_dataloader(
@@ -339,8 +332,9 @@
                 metadata=factor_args.to_str_dict(),
             )
         self.state.wait_for_everyone()
-        del covariance_factors, loader
         self.logger.info(f"Saved covariance matrices at `{factors_output_dir}`.")
+        del num_data_processed, covariance_factors, loader
+        self._reset_memory()

     all_end_time = get_time(state=self.state)
     elapsed_time = all_end_time - all_start_time
@@ -350,7 +344,7 @@
         self.aggregate_covariance_matrices(factors_name=factors_name)
         self.logger.info(f"Saved aggregated covariance matrices at `{factors_output_dir}`.")
     self.state.wait_for_everyone()
-    self._log_profile_summary()
+    self._log_profile_summary(name=f"factors_{factors_name}_covariance")

 @torch.no_grad()
 def aggregate_covariance_matrices(
@@ -368,7 +362,7 @@
     if factor_args is None:
         error_msg = (
             f"Arguments for factors with name `{factors_name}` was not found when trying to "
-            f"aggregated covariance matrices."
+            f"aggregate covariance matrices."
) self.logger.error(error_msg) raise ValueError(error_msg) @@ -376,9 +370,9 @@ def aggregate_covariance_matrices( with self.profiler.profile("Aggregate Covariance"): self._aggregate_factors( factors_name=factors_name, - data_partition_size=factor_args.covariance_data_partition_size, - module_partition_size=factor_args.covariance_module_partition_size, - exists_fnc=covariance_matrices_exist, + data_partitions=factor_args.covariance_data_partitions, + module_partitions=factor_args.covariance_module_partitions, + exist_fnc=covariance_matrices_exist, load_fnc=load_covariance_matrices, save_fnc=save_covariance_matrices, ) @@ -390,26 +384,25 @@ def perform_eigendecomposition( overwrite_output_dir: bool = False, load_from_factors_name: Optional[str] = None, ) -> None: - """Performs Eigendecomposition on all covariance matrices. + """Performs eigendecomposition on all covariance matrices. Args: factors_name (str): The unique identifier for the factor, used to organize and retrieve the results. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, the default values of - `FactorArguments` will be used. + Arguments for factor computation. overwrite_output_dir (bool, optional): - If True, the existing factors with the same `factors_name` will be overwritten. + Whether to overwrite existing output. load_from_factors_name (str, optional): The `factor_name` to load covariance matrices from. By default, covariance matrices with the same `factor_name` will be used. """ - self.logger.debug(f"Performing Eigendecomposition with parameters: {locals()}") + self.logger.debug(f"Performing eigendecomposition with parameters: {locals()}") factors_output_dir = self.factors_output_dir(factors_name=factors_name) os.makedirs(factors_output_dir, exist_ok=True) if eigendecomposition_exist(output_dir=factors_output_dir) and not overwrite_output_dir: - self.logger.info(f"Found existing Eigendecomposition results at `{factors_output_dir}`. Skipping.") + self.logger.info(f"Found existing eigendecomposition results at `{factors_output_dir}`. Skipping.") return factor_args = self._configure_and_save_factor_args( @@ -418,7 +411,7 @@ def perform_eigendecomposition( if not FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition: self.logger.info( - f"Strategy `{factor_args.strategy}` does not require performing Eigendecomposition. Skipping." + f"Strategy `{factor_args.strategy}` does not require performing eigendecomposition. Skipping." ) return None @@ -430,7 +423,7 @@ def perform_eigendecomposition( if not covariance_matrices_exist(output_dir=load_factors_output_dir): error_msg = ( f"Covariance matrices not found at `{load_factors_output_dir}`. " - f"To perform Eigendecomposition, covariance matrices need to be first computed." + f"To perform eigendecomposition, covariance matrices need to be first computed." ) self.logger.error(error_msg) raise FactorsNotFoundError(error_msg) @@ -439,7 +432,7 @@ def perform_eigendecomposition( covariance_factors = load_covariance_matrices(output_dir=load_factors_output_dir) if load_from_factors_name is not None and self.state.is_main_process: - # Saves the loaded covariances to the current factor output directory. + # Save the loaded covariances to the current factor output directory. 
with self.profiler.profile("Save Covariance"): save_covariance_matrices(output_dir=factors_output_dir, factors=covariance_factors) loaded_factor_args = self.load_factor_args(factors_name=load_from_factors_name) @@ -451,9 +444,9 @@ def perform_eigendecomposition( ) self.state.wait_for_everyone() + self._reset_memory() eigen_factors = None if self.state.is_main_process: - release_memory() start_time = time.time() with self.profiler.profile("Perform Eigendecomposition"): eigen_factors = perform_eigendecomposition( @@ -465,15 +458,17 @@ def perform_eigendecomposition( ) end_time = time.time() elapsed_time = end_time - start_time - self.logger.info(f"Performed Eigendecomposition in {elapsed_time:.2f} seconds.") + self.logger.info(f"Performed eigendecomposition in {elapsed_time:.2f} seconds.") with self.profiler.profile("Save Eigendecomposition"): save_eigendecomposition( output_dir=factors_output_dir, factors=eigen_factors, metadata=factor_args.to_str_dict() ) - self.logger.info(f"Saved Eigendecomposition results at `{factors_output_dir}`.") + self.logger.info(f"Saved eigendecomposition results at `{factors_output_dir}`.") + del eigen_factors + self._reset_memory() self.state.wait_for_everyone() - self._log_profile_summary() + self._log_profile_summary(name=f"factors_{factors_name}_eigendecomposition") def fit_lambda_matrices( self, @@ -494,7 +489,7 @@ def fit_lambda_matrices( factors_name (str): The unique identifier for the factor, used to organize and retrieve the results. dataset (data.Dataset): - The dataset that will be used to fit Lambda matrices. + The dataset that will be used for fitting Lambda matrices. per_device_batch_size (int, optional): The per-device batch size used to fit the factors. If not specified, executable batch size is automatically determined. @@ -503,8 +498,7 @@ def fit_lambda_matrices( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, the default values of - `FactorArguments` will be used. + Arguments for factor computation. target_data_partitions(Sequence[int], int, optional): The list of data partition to fit Lambda matrices. By default, Lambda matrices will be computed for all partitions. @@ -512,9 +506,9 @@ def fit_lambda_matrices( The list of module partition to fit Lambda matrices. By default, Lambda matrices will be computed for all partitions. overwrite_output_dir (bool, optional): - If True, the existing factors with the same `factors_name` will be overwritten. + Whether to overwrite existing output. load_from_factors_name (str, optional): - The `factor_name` to load Eigendecomposition results from. By default, Eigendecomposition + The `factor_name` to load eigendecomposition results from. By default, eigendecomposition results with the same `factor_name` will be used. """ self.logger.debug(f"Fitting Lambda matrices with parameters: {locals()}") @@ -544,7 +538,7 @@ def fit_lambda_matrices( if load_from_factors_name is not None: self.logger.info( - f"Will be loading Eigendecomposition results from factors with name `{load_from_factors_name}`." + f"Will be loading eigendecomposition results from factors with name `{load_from_factors_name}`." ) load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) else: @@ -556,7 +550,7 @@ def fit_lambda_matrices( ): error_msg = ( f"Eigendecomposition results not found at `{load_factors_output_dir}`. 
" - f"To fit Lambda matrices for `{factor_args.strategy}`, Eigendecomposition must be " + f"To fit Lambda matrices for `{factor_args.strategy}`, eigendecomposition must be " f"performed before computing Lambda matrices." ) self.logger.error(error_msg) @@ -584,7 +578,7 @@ def fit_lambda_matrices( total_data_examples = min([factor_args.lambda_max_examples, len(dataset)]) self.logger.info(f"Total data examples to fit Lambda matrices: {total_data_examples}.") - no_partition = factor_args.lambda_data_partition_size == 1 and factor_args.lambda_module_partition_size == 1 + no_partition = factor_args.lambda_data_partitions == 1 and factor_args.lambda_module_partitions == 1 partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( @@ -596,17 +590,17 @@ def fit_lambda_matrices( data_partition_indices, target_data_partitions = self._get_data_partition( total_data_examples=total_data_examples, - data_partition_size=factor_args.lambda_data_partition_size, + data_partitions=factor_args.lambda_data_partitions, target_data_partitions=target_data_partitions, ) - max_partition_examples = total_data_examples // factor_args.lambda_data_partition_size + max_partition_examples = total_data_examples // factor_args.lambda_data_partitions module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=factor_args.lambda_module_partition_size, + module_partitions=factor_args.lambda_module_partitions, target_module_partitions=target_module_partitions, ) if max_partition_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples. Try reducing the number of processes." + error_msg = "The number of processes are larger than total data examples. Try reducing number of processes." self.logger.error(error_msg) raise ValueError(error_msg) @@ -656,7 +650,7 @@ def fit_lambda_matrices( total_data_examples=max_partition_examples, ) - release_memory() + self._reset_memory() start_time = get_time(state=self.state) with self.profiler.profile("Fit Lambda"): loader = self._get_dataloader( @@ -692,8 +686,9 @@ def fit_lambda_matrices( metadata=factor_args.to_str_dict(), ) self.state.wait_for_everyone() - del lambda_factors, loader self.logger.info(f"Saved Lambda matrices at `{factors_output_dir}`.") + del num_data_processed, lambda_factors, loader + self._reset_memory() all_end_time = get_time(state=self.state) elapsed_time = all_end_time - all_start_time @@ -703,7 +698,7 @@ def fit_lambda_matrices( self.aggregate_lambda_matrices(factors_name=factors_name) self.logger.info(f"Saved aggregated Lambda matrices at `{factors_output_dir}`.") self.state.wait_for_everyone() - self._log_profile_summary() + self._log_profile_summary(name=f"factors_{factors_name}_lambda") @torch.no_grad() def aggregate_lambda_matrices( @@ -721,7 +716,7 @@ def aggregate_lambda_matrices( if factor_args is None: error_msg = ( f"Arguments for factors with name `{factors_name}` was not found when trying " - f"to aggregated Lambda matrices." + f"to aggregate Lambda matrices." 
) self.logger.error(error_msg) raise ValueError(error_msg) @@ -729,9 +724,9 @@ def aggregate_lambda_matrices( with self.profiler.profile("Aggregate Lambda"): self._aggregate_factors( factors_name=factors_name, - data_partition_size=factor_args.lambda_data_partition_size, - module_partition_size=factor_args.lambda_module_partition_size, - exists_fnc=lambda_matrices_exist, + data_partitions=factor_args.lambda_data_partitions, + module_partitions=factor_args.lambda_module_partitions, + exist_fnc=lambda_matrices_exist, load_fnc=load_lambda_matrices, save_fnc=save_lambda_matrices, ) diff --git a/kronfluence/computer/score_computer.py b/kronfluence/computer/score_computer.py index 2b8fc43..d606124 100644 --- a/kronfluence/computer/score_computer.py +++ b/kronfluence/computer/score_computer.py @@ -8,9 +8,8 @@ from kronfluence.arguments import FactorArguments, ScoreArguments from kronfluence.computer.computer import Computer -from kronfluence.module.tracked_module import ModuleMode -from kronfluence.module.utils import set_mode from kronfluence.score.pairwise import ( + compute_pairwise_query_aggregated_scores_with_loaders, compute_pairwise_scores_with_loaders, load_pairwise_scores, pairwise_scores_exist, @@ -23,12 +22,15 @@ save_self_scores, self_scores_exist, ) -from kronfluence.utils.constants import FACTOR_TYPE, SCORE_TYPE +from kronfluence.utils.constants import ( + FACTOR_ARGUMENTS_NAME, + FACTOR_TYPE, + SCORE_ARGUMENTS_NAME, + SCORE_TYPE, +) from kronfluence.utils.dataset import DataLoaderKwargs, find_executable_batch_size from kronfluence.utils.exceptions import FactorsNotFoundError from kronfluence.utils.logger import get_time -from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME, SCORE_ARGUMENTS_NAME -from kronfluence.utils.state import release_memory class ScoreComputer(Computer): @@ -41,7 +43,7 @@ def _configure_and_save_score_args( factors_name: str, overwrite_output_dir: bool, ) -> Tuple[FactorArguments, ScoreArguments]: - """Configures the provided score arguments and saves it in disk.""" + """Configures and saves score arguments to disk.""" if score_args is None: score_args = ScoreArguments() self.logger.info(f"Score arguments not provided. Using the default configuration: {score_args}.") @@ -51,7 +53,7 @@ def _configure_and_save_score_args( factor_args = self.load_factor_args(factors_name=factors_name) factors_output_dir = self.factors_output_dir(factors_name=factors_name) if factor_args is None: - error_msg = f"Factors with name `{factors_name}` was not found at `{factors_output_dir}`." + error_msg = f"Factors with name `{factors_name}` not found at `{factors_output_dir}`." self.logger.error(error_msg) raise FactorsNotFoundError(error_msg) self.logger.info(f"Loaded `FactorArguments` with configuration: {factor_args}.") @@ -77,7 +79,7 @@ def _aggregate_scores( self, scores_name: str, score_args: ScoreArguments, - exists_fnc: Callable, + exist_fnc: Callable, load_fnc: Callable, save_fnc: Callable, dim: int, @@ -85,47 +87,43 @@ def _aggregate_scores( """Aggregates influence scores computed for all data and module partitions.""" scores_output_dir = self.scores_output_dir(scores_name=scores_name) if not scores_output_dir.exists(): - error_msg = ( - f"Scores output directory `{scores_output_dir}` is not found " - f"when trying to aggregate partitioned scores." - ) + error_msg = f"Scores directory `{scores_output_dir}` not found when trying to aggregate scores." 
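# A condensed sketch of how `_aggregate_scores` combines partitioned results below in this
# diff: scores from module partitions are summed, while data partitions are concatenated
# along `dim` (dim=1 for pairwise scores over training examples, dim=0 for self scores),
# unless train gradients are aggregated, in which case data partitions are summed as well.
# Shapes and names here are illustrative, not the library's on-disk layout.
import torch

def combine(per_partition_scores, data_partitions, module_partitions, dim, sum_over_data):
    out = {}
    for i in range(data_partitions):
        module_sum = {}
        for j in range(module_partitions):
            for name, scores in per_partition_scores[(i, j)].items():
                if name not in module_sum:
                    module_sum[name] = torch.zeros_like(scores)
                module_sum[name].add_(scores)
        for name, scores in module_sum.items():
            if name not in out:
                out[name] = scores.clone()
            elif sum_over_data:
                out[name].add_(scores)
            else:
                out[name] = torch.cat((out[name], scores), dim=dim)
    return out

# Pairwise scores: 4 queries x 3 train examples per data partition, 2 data partitions.
parts = {(i, 0): {"all_modules": torch.randn(4, 3)} for i in range(2)}
print(combine(parts, 2, 1, dim=1, sum_over_data=False)["all_modules"].shape)  # (4, 6)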
self.logger.error(error_msg) raise FileNotFoundError(error_msg) - data_partition_size = score_args.data_partition_size - module_partition_size = score_args.module_partition_size all_required_partitions = [ - (i, j) for i in range(score_args.data_partition_size) for j in range(score_args.module_partition_size) + (i, j) for i in range(score_args.data_partitions) for j in range(score_args.module_partitions) ] all_partition_exists = all( - exists_fnc(output_dir=scores_output_dir, partition=partition) for partition in all_required_partitions + exist_fnc(output_dir=scores_output_dir, partition=partition) for partition in all_required_partitions ) if not all_partition_exists: - self.logger.info("Influence scores are not aggregated as scores for some partitions are not yet computed.") + self.logger.info("Scores are not aggregated as scores for some partitions are not yet computed.") return start_time = time.time() aggregated_scores: SCORE_TYPE = {} - with self.profiler.profile("Aggregate Score"): - for data_partition in range(data_partition_size): - aggregated_module_scores = {} + for data_partition in range(score_args.data_partitions): + aggregated_module_scores = {} - for module_partition in range(module_partition_size): - loaded_scores = load_fnc( - output_dir=scores_output_dir, - partition=(data_partition, module_partition), - ) + for module_partition in range(score_args.module_partitions): + loaded_scores = load_fnc( + output_dir=scores_output_dir, + partition=(data_partition, module_partition), + ) - for module_name, scores in loaded_scores.items(): - if module_name not in aggregated_module_scores: - aggregated_module_scores[module_name] = scores - else: - aggregated_module_scores[module_name].add_(scores) - del loaded_scores + for module_name, scores in loaded_scores.items(): + if module_name not in aggregated_module_scores: + aggregated_module_scores[module_name] = torch.zeros_like(scores, requires_grad=False) + aggregated_module_scores[module_name].add_(scores) + del loaded_scores - for module_name, scores in aggregated_module_scores.items(): - if module_name not in aggregated_scores: - aggregated_scores[module_name] = scores + for module_name, scores in aggregated_module_scores.items(): + if module_name not in aggregated_scores: + aggregated_scores[module_name] = scores.clone() + else: + if score_args.aggregate_train_gradients: + aggregated_scores[module_name].add_(scores) else: aggregated_scores[module_name] = torch.cat( ( @@ -134,10 +132,10 @@ def _aggregate_scores( ), dim=dim, ) - save_fnc(output_dir=scores_output_dir, scores=aggregated_scores, metadata=score_args.to_str_dict()) + save_fnc(output_dir=scores_output_dir, scores=aggregated_scores, metadata=score_args.to_str_dict()) end_time = time.time() elapsed_time = end_time - start_time - self.logger.info(f"Aggregated all partitioned scores in {elapsed_time:.2f} seconds.") + self.logger.info(f"Aggregated all scores in {elapsed_time:.2f} seconds.") return aggregated_scores def _find_executable_pairwise_scores_batch_size( @@ -156,8 +154,8 @@ def _find_executable_pairwise_scores_batch_size( """Automatically finds executable training batch size for computing pairwise influence scores.""" if self.state.use_distributed: error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " - "Please manually configure the batch size by passing in `per_device_train_batch_size`." + "Automatic batch size search is not supported for multi-GPU setting. 
" + "Please manually configure the batch size by passing in `per_device_batch_size`." ) self.logger.error(error_msg) raise NotImplementedError(error_msg) @@ -174,9 +172,7 @@ def _find_executable_pairwise_scores_batch_size( def executable_batch_size_func(batch_size: int) -> None: self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") # Releases all memory that could be caused by the previous OOM. - self.model.zero_grad(set_to_none=True) - set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) - release_memory() + self._reset_memory() total_batch_size = batch_size * self.state.num_processes query_loader = self._get_dataloader( dataset=query_dataset, @@ -193,7 +189,12 @@ def executable_batch_size_func(batch_size: int) -> None: allow_duplicates=True, stack=True, ) - compute_pairwise_scores_with_loaders( + func = ( + compute_pairwise_scores_with_loaders + if not score_args.aggregate_query_gradients + else compute_pairwise_query_aggregated_scores_with_loaders + ) + func( model=self.model, state=self.state, task=self.task, @@ -231,9 +232,7 @@ def compute_pairwise_scores( target_module_partitions: Optional[Sequence[int]] = None, overwrite_output_dir: bool = False, ) -> Optional[SCORE_TYPE]: - """Computes pairwise influence scores for the given score configuration. As an example, - for Q query dataset and T training dataset, the pairwise influence scores are - 2-dimensional matrix with dimension `Q x T`. + """Computes pairwise influence scores with the given score configuration. Args: scores_name (str): @@ -260,8 +259,7 @@ def compute_pairwise_scores( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. score_args (ScoreArguments, optional): - Arguments related to computing the pairwise scores. If not specified, the default values - of `ScoreArguments` will be used. + Arguments for score computation. target_data_partitions (Sequence[int], optional): Specific data partitions to compute influence scores. If not specified, scores for all data partitions will be computed. @@ -269,7 +267,7 @@ def compute_pairwise_scores( Specific module partitions to compute influence scores. If not specified, scores for all module partitions will be computed. overwrite_output_dir (bool, optional): - If True, the existing factors with the same name will be overwritten. + Whether to overwrite existing output. """ self.logger.debug(f"Computing pairwise scores with parameters: {locals()}") @@ -286,6 +284,30 @@ def compute_pairwise_scores( overwrite_output_dir=overwrite_output_dir, ) + if score_args.compute_per_token_scores and score_args.aggregate_train_gradients: + warning_msg = ( + "Token-wise influence computation is not compatible with `aggregate_train_gradients=True`. " + "Disabling `compute_per_token_scores`." + ) + score_args.compute_per_token_scores = False + self.logger.warning(warning_msg) + + if score_args.compute_per_token_scores and factor_args.has_shared_parameters: + warning_msg = ( + "Token-wise influence computation is not compatible with `has_shared_parameters=True`. " + "Disabling `compute_per_token_scores`." + ) + score_args.compute_per_token_scores = False + self.logger.warning(warning_msg) + + if score_args.compute_per_token_scores and self.task.enable_post_process_per_sample_gradient: + warning_msg = ( + "Token-wise influence computation is not compatible with tasks that requires " + "`enable_post_process_per_sample_gradient`. Disabling `compute_per_token_scores`." 
+ ) + score_args.compute_per_token_scores = False + self.logger.warning(warning_msg) + dataloader_params = self._configure_dataloader(dataloader_kwargs) if self.state.is_main_process: self._save_dataset_metadata( @@ -315,7 +337,7 @@ def compute_pairwise_scores( factors_name=factors_name, ) - no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + no_partition = score_args.data_partitions == 1 and score_args.module_partitions == 1 partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( @@ -327,12 +349,12 @@ def compute_pairwise_scores( data_partition_indices, target_data_partitions = self._get_data_partition( total_data_examples=len(train_dataset), - data_partition_size=score_args.data_partition_size, + data_partitions=score_args.data_partitions, target_data_partitions=target_data_partitions, ) - max_partition_examples = len(train_dataset) // score_args.data_partition_size + max_partition_examples = len(train_dataset) // score_args.data_partitions module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=score_args.module_partition_size, + module_partitions=score_args.module_partitions, target_module_partitions=target_module_partitions, ) @@ -359,14 +381,16 @@ def compute_pairwise_scores( start_index, end_index = data_partition_indices[data_partition] self.logger.info( - f"Fitting pairwise scores with data indices ({start_index}, {end_index}) and " + f"Computing pairwise scores with data indices ({start_index}, {end_index}) and " f"modules {module_partition_names[module_partition]}." ) if per_device_train_batch_size is None: per_device_train_batch_size = self._find_executable_pairwise_scores_batch_size( query_dataset=query_dataset, - per_device_query_batch_size=per_device_query_batch_size, + per_device_query_batch_size=per_device_query_batch_size + if not score_args.aggregate_query_gradients + else 1, train_dataset=train_dataset, initial_per_device_train_batch_size_attempt=initial_per_device_train_batch_size_attempt, loaded_factors=loaded_factors, @@ -377,24 +401,29 @@ def compute_pairwise_scores( tracked_modules_name=module_partition_names[module_partition], ) - release_memory() + self._reset_memory() start_time = get_time(state=self.state) with self.profiler.profile("Compute Pairwise Score"): query_loader = self._get_dataloader( dataset=query_dataset, per_device_batch_size=per_device_query_batch_size, dataloader_params=dataloader_params, - allow_duplicates=True, + allow_duplicates=not score_args.aggregate_query_gradients, ) train_loader = self._get_dataloader( dataset=train_dataset, per_device_batch_size=per_device_train_batch_size, indices=list(range(start_index, end_index)), dataloader_params=dataloader_params, - allow_duplicates=True, - stack=True, + allow_duplicates=not score_args.aggregate_train_gradients, + stack=not score_args.aggregate_train_gradients, ) - scores = compute_pairwise_scores_with_loaders( + func = ( + compute_pairwise_scores_with_loaders + if not score_args.aggregate_query_gradients + else compute_pairwise_query_aggregated_scores_with_loaders + ) + scores = func( model=self.model, state=self.state, task=self.task, @@ -421,6 +450,7 @@ def compute_pairwise_scores( ) self.state.wait_for_everyone() del scores, query_loader, train_loader + self._reset_memory() self.logger.info(f"Saved pairwise scores at {scores_output_dir}.") all_end_time = get_time(state=self.state) @@ -431,7 +461,7 @@ def 
compute_pairwise_scores( self.aggregate_pairwise_scores(scores_name=scores_name) self.logger.info(f"Saved aggregated pairwise scores at `{scores_output_dir}`.") self.state.wait_for_everyone() - self._log_profile_summary() + self._log_profile_summary(name=f"scores_{scores_name}_pairwise") @torch.no_grad() def aggregate_pairwise_scores(self, scores_name: str) -> None: @@ -446,19 +476,20 @@ def aggregate_pairwise_scores(self, scores_name: str) -> None: if score_args is None: error_msg = ( f"Arguments for scores with name `{score_args}` was not found when trying " - f"to aggregated pairwise influence scores." + f"to aggregate pairwise influence scores." ) self.logger.error(error_msg) raise ValueError(error_msg) - self._aggregate_scores( - scores_name=scores_name, - score_args=score_args, - exists_fnc=pairwise_scores_exist, - load_fnc=load_pairwise_scores, - save_fnc=save_pairwise_scores, - dim=1, - ) + with self.profiler.profile("Aggregate Score"): + self._aggregate_scores( + scores_name=scores_name, + score_args=score_args, + exist_fnc=pairwise_scores_exist, + load_fnc=load_pairwise_scores, + save_fnc=save_pairwise_scores, + dim=1, + ) def _find_executable_self_scores_batch_size( self, @@ -474,8 +505,8 @@ def _find_executable_self_scores_batch_size( """Automatically finds executable training batch size for computing self-influence scores.""" if self.state.use_distributed: error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " - "Please manually configure the batch size by passing in `per_device_train_batch_size`." + "Automatic batch size search is not supported for multi-GPU setting. " + "Please manually configure the batch size by passing in `per_device_batch_size`." ) self.logger.error(error_msg) raise NotImplementedError(error_msg) @@ -491,9 +522,7 @@ def _find_executable_self_scores_batch_size( def executable_batch_size_func(batch_size: int) -> None: self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") # Releases all memory that could be caused by the previous OOM. - self.model.zero_grad(set_to_none=True) - set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) - release_memory() + self._reset_memory() total_batch_size = batch_size * self.state.num_processes train_loader = self._get_dataloader( dataset=train_dataset, @@ -540,8 +569,7 @@ def compute_self_scores( target_module_partitions: Optional[Sequence[int]] = None, overwrite_output_dir: bool = False, ) -> Optional[SCORE_TYPE]: - """Computes self-influence scores for the given score configuration. As an example, - for training dataset with T examples, the self-influence scores are represented as T-dimensional vector. + """Computes self-influence scores with the given score configuration. Args: scores_name (str): @@ -561,8 +589,7 @@ def compute_self_scores( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. score_args (ScoreArguments, optional): - Arguments related to computing the self-influence scores. If not specified, the default values - of `ScoreArguments` will be used. + Arguments for score computation. target_data_partitions (Sequence[int], optional): Specific data partitions to compute influence scores. If not specified, scores for all data partitions will be computed. @@ -570,7 +597,7 @@ def compute_self_scores( Specific module partitions to compute influence scores. If not specified, scores for all module partitions will be computed. 
overwrite_output_dir (bool, optional): - If True, the existing factors with the same name will be overwritten. + Whether to overwrite existing output. """ self.logger.debug(f"Computing self-influence scores with parameters: {locals()}") @@ -587,11 +614,28 @@ def compute_self_scores( overwrite_output_dir=overwrite_output_dir, ) - if score_args.query_gradient_rank is not None: + if score_args.query_gradient_accumulation_steps != 1: + warning_msg = "Query gradient accumulation is not supported for self-influence computation." + score_args.query_gradient_accumulation_steps = 1 + self.logger.warning(warning_msg) + + if score_args.query_gradient_low_rank is not None: warning_msg = ( "Low rank query gradient approximation is not supported for self-influence computation. " "No low rank query approximation will be performed." ) + score_args.query_gradient_low_rank = None + self.logger.warning(warning_msg) + + if score_args.aggregate_query_gradients or score_args.aggregate_train_gradients: + warning_msg = "Query or train gradient aggregation is not supported for self-influence computation." + score_args.aggregate_train_gradients = False + score_args.aggregate_query_gradients = False + self.logger.warning(warning_msg) + + if score_args.compute_per_token_scores: + warning_msg = "Token-wise influence computation is not compatible with self-influence scores. " + score_args.compute_per_token_scores = False self.logger.warning(warning_msg) dataloader_params = self._configure_dataloader(dataloader_kwargs) @@ -612,7 +656,7 @@ def compute_self_scores( factors_name=factors_name, ) - no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + no_partition = score_args.data_partitions == 1 and score_args.module_partitions == 1 partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( @@ -624,12 +668,12 @@ def compute_self_scores( data_partition_indices, target_data_partitions = self._get_data_partition( total_data_examples=len(train_dataset), - data_partition_size=score_args.data_partition_size, + data_partitions=score_args.data_partitions, target_data_partitions=target_data_partitions, ) - max_partition_examples = len(train_dataset) // score_args.data_partition_size + max_partition_examples = len(train_dataset) // score_args.data_partitions module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=score_args.module_partition_size, + module_partitions=score_args.module_partitions, target_module_partitions=target_module_partitions, ) @@ -656,7 +700,7 @@ def compute_self_scores( start_index, end_index = data_partition_indices[data_partition] self.logger.info( - f"Fitting self-influence scores with data indices ({start_index}, {end_index}) and " + f"Computing self-influence scores with data indices ({start_index}, {end_index}) and " f"modules {module_partition_names[module_partition]}." 
) @@ -672,7 +716,7 @@ def compute_self_scores( tracked_modules_name=module_partition_names[module_partition], ) - release_memory() + self._reset_memory() start_time = get_time(state=self.state) with self.profiler.profile("Compute Self-Influence Score"): train_loader = self._get_dataloader( @@ -712,6 +756,7 @@ def compute_self_scores( ) self.state.wait_for_everyone() del scores, train_loader + self._reset_memory() self.logger.info(f"Saved self-influence scores at `{scores_output_dir}`.") all_end_time = get_time(state=self.state) @@ -722,7 +767,7 @@ def compute_self_scores( self.aggregate_self_scores(scores_name=scores_name) self.logger.info(f"Saved aggregated self-influence scores at `{scores_output_dir}`.") self.state.wait_for_everyone() - self._log_profile_summary() + self._log_profile_summary(name=f"scores_{scores_name}_self") @torch.no_grad() def aggregate_self_scores(self, scores_name: str) -> None: @@ -737,15 +782,16 @@ def aggregate_self_scores(self, scores_name: str) -> None: if score_args is None: error_msg = ( f"Arguments for scores with name `{score_args}` was not found when trying " - f"to aggregated self-influence scores." + f"to aggregate self-influence scores." ) self.logger.error(error_msg) raise ValueError(error_msg) + score_args.aggregate_query_gradients = score_args.aggregate_train_gradients = False self._aggregate_scores( scores_name=scores_name, score_args=score_args, - exists_fnc=self_scores_exist, + exist_fnc=self_scores_exist, load_fnc=load_self_scores, save_fnc=save_self_scores, dim=0, diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index 976a261..9b7591c 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -40,32 +40,32 @@ def __init_subclass__(cls, factor_strategy: Optional[FactorStrategy] = None, **k @property @abstractmethod def requires_covariance_matrices(self) -> bool: - """Returns True if the strategy requires computing covariance matrices.""" + """Returns `True` if the strategy requires computing covariance matrices.""" raise NotImplementedError("Subclasses must implement the `requires_covariance_matrices` property.") @property @abstractmethod def requires_eigendecomposition(self) -> bool: - """Returns True if the strategy requires performing Eigendecomposition.""" + """Returns `True` if the strategy requires performing Eigendecomposition.""" raise NotImplementedError("Subclasses must implement the `requires_eigendecomposition` property.") @property @abstractmethod def requires_lambda_matrices(self) -> bool: - """Returns True if the strategy requires computing Lambda matrices.""" + """Returns `True` if the strategy requires computing Lambda matrices.""" raise NotImplementedError("Subclasses must implement the `requires_lambda_matrices` property.") @property @abstractmethod def requires_eigendecomposition_for_lambda(self) -> bool: - """Returns True if the strategy requires loading Eigendecomposition results, before computing + """Returns `True` if the strategy requires loading Eigendecomposition results, before computing Lambda matrices.""" raise NotImplementedError("Subclasses must implement the `requires_eigendecomposition_for_lambda` property.") @property @abstractmethod def requires_covariance_matrices_for_precondition(self) -> bool: - """Returns True if the strategy requires loading covariance matrices, before computing + """Returns `True` if the strategy requires loading covariance matrices, before computing preconditioned gradient.""" raise NotImplementedError( "Subclasses must implement the 
`requires_covariance_matrices_for_precondition` property." @@ -74,7 +74,7 @@ def requires_covariance_matrices_for_precondition(self) -> bool: @property @abstractmethod def requires_eigendecomposition_for_precondition(self) -> bool: - """Returns True if the strategy requires loading Eigendecomposition results, before computing + """Returns `True` if the strategy requires loading Eigendecomposition results, before computing preconditioned gradient.""" raise NotImplementedError( "Subclasses must implement the `requires_eigendecomposition_for_precondition` property." @@ -83,34 +83,42 @@ def requires_eigendecomposition_for_precondition(self) -> bool: @property @abstractmethod def requires_lambda_matrices_for_precondition(self) -> bool: - """Returns True if the strategy requires loading Lambda matrices, before computing + """Returns `True` if the strategy requires loading Lambda matrices, before computing the preconditioned gradient.""" raise NotImplementedError("Subclasses must implement the `requires_lambda_matrices_for_precondition` property.") + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + """Performs necessary operations before computing the preconditioned gradient. + + Args: + storage (STORAGE_TYPE): + A dictionary containing various factors required to compute the preconditioned gradient. + See `.storage` in `TrackedModule` for details. + score_args (ScoreArguments): + Arguments for computing the preconditioned gradient. + device (torch.device): + Device used for computing the preconditioned gradient. + """ + @abstractmethod def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - """Preconditions the per-sample-gradient. The per-sample-gradient is a 3-dimensional - tensor with the shape `batch_size x output_dim x input_dim`. + """Preconditions the per-sample gradient. The per-sample gradient is a 3-dimensional + tensor with shape `batch_size x output_dim x input_dim`. Args: gradient (torch.Tensor): - The per-sample-gradient tensor. + The per-sample gradient tensor. storage (STORAGE_TYPE): A dictionary containing various factors required to compute the preconditioned gradient. See `.storage` in `TrackedModule` for details. - damping (float, optional): - The damping factor when computing the preconditioned gradient. If not provided, sets - the damping term with some heuristic. Returns: torch.Tensor: - The preconditioned per-sample-gradient tensor. The dimension should be the same as the original - per-sample-gradient. + The preconditioned per-sample gradient tensor. 
""" raise NotImplementedError("Subclasses must implement the `precondition_gradient` method.") @@ -150,9 +158,8 @@ def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - del storage, damping + del storage return gradient @@ -187,18 +194,23 @@ def requires_eigendecomposition_for_precondition(self) -> bool: def requires_lambda_matrices_for_precondition(self) -> bool: return True + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) + lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) + damping_factor = score_args.damping_factor + if damping_factor is None: + damping_factor = 0.1 * torch.mean(lambda_matrix) + lambda_matrix.add_(damping_factor) + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() + storage[NUM_LAMBDA_PROCESSED] = None + def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) - num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) - lambda_matrix = lambda_matrix / num_lambda_processed - if damping is None: - damping = 0.1 * torch.mean(lambda_matrix) - return gradient / (lambda_matrix + damping) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) + return gradient / lambda_matrix class Kfac(FactorConfig, factor_strategy=FactorStrategy.KFAC): @@ -235,27 +247,38 @@ def requires_eigendecomposition_for_precondition(self) -> bool: def requires_lambda_matrices_for_precondition(self) -> bool: return False + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + storage[ACTIVATION_EIGENVECTORS_NAME] = ( + storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous() + ) + storage[GRADIENT_EIGENVECTORS_NAME] = ( + storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous() + ) + activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device) + gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device) + lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) + damping_factor = score_args.damping_factor + if damping_factor is None: + damping_factor = 0.1 * torch.mean(lambda_matrix) + lambda_matrix.add_(damping_factor) + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() + storage[NUM_LAMBDA_PROCESSED] = None + storage[ACTIVATION_EIGENVALUES_NAME] = None + storage[GRADIENT_EIGENVALUES_NAME] = None + @torch.no_grad() def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) - gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) - # The eigenvalues have the Kronecker structure for KFAC. 
- lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) - + activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(device=gradient.device) + gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) - - if damping is None: - damping = 0.1 * torch.mean(lambda_matrix) - - gradient.div_(lambda_matrix + damping) - return torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) + gradient.div_(lambda_matrix) + gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) + return gradient class Ekfac(FactorConfig, factor_strategy=FactorStrategy.EKFAC): @@ -292,23 +315,34 @@ def requires_eigendecomposition_for_precondition(self) -> bool: def requires_lambda_matrices_for_precondition(self) -> bool: return True + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + storage[ACTIVATION_EIGENVECTORS_NAME] = ( + storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous() + ) + storage[GRADIENT_EIGENVECTORS_NAME] = ( + storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous() + ) + storage[ACTIVATION_EIGENVALUES_NAME] = None + storage[GRADIENT_EIGENVALUES_NAME] = None + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) + lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) + damping_factor = score_args.damping_factor + if damping_factor is None: + damping_factor = 0.1 * torch.mean(lambda_matrix) + lambda_matrix.add_(damping_factor) + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() + storage[NUM_LAMBDA_PROCESSED] = None + @torch.no_grad() def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) - num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) - lambda_matrix = lambda_matrix / num_lambda_processed + activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(device=gradient.device) + gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) - - if damping is None: - damping = 0.1 * torch.mean(lambda_matrix) - - gradient.div_(lambda_matrix + damping) + gradient.div_(lambda_matrix) gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) return gradient diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 86ad2ca..3c7d37b 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -18,12 +18,13 @@ set_attention_mask, set_gradient_scale, set_mode, - synchronize_covariance_matrices, + synchronize_modules, update_factor_args, ) from kronfluence.task import 
Task from kronfluence.utils.constants import ( COVARIANCE_FACTOR_NAMES, + DISTRIBUTED_SYNC_INTERVAL, FACTOR_TYPE, PARTITION_TYPE, ) @@ -36,7 +37,24 @@ def covariance_matrices_save_path( factor_name: str, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading covariance matrices.""" + """Generates the path for saving or loading covariance matrices. + + Args: + output_dir (Path): + Directory to save or load the matrices. + factor_name (str): + Name of the factor (must be in `COVARIANCE_FACTOR_NAMES`). + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the covariance matrix file. + + Raises: + AssertionError: + If `factor_name` is not in `COVARIANCE_FACTOR_NAMES`. + """ assert factor_name in COVARIANCE_FACTOR_NAMES if partition is not None: data_partition, module_partition = partition @@ -52,7 +70,22 @@ def save_covariance_matrices( partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves covariance matrices to disk.""" + """Saves covariance matrices to disk. + + Args: + output_dir (Path): + Directory to save the matrices. + factors (FACTOR_TYPE): + Dictionary of factors to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the factors. + + Raises: + AssertionError: + If factors keys don't match `COVARIANCE_FACTOR_NAMES`. + """ assert set(factors.keys()) == set(COVARIANCE_FACTOR_NAMES) for factor_name in factors: save_path = covariance_matrices_save_path( @@ -67,7 +100,18 @@ def load_covariance_matrices( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> FACTOR_TYPE: - """Loads covariance matrices from disk.""" + """Loads covariance matrices from disk. + + Args: + output_dir (Path): + Directory to load the matrices from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + FACTOR_TYPE: + Dictionary of loaded covariance factors. + """ covariance_factors = {} for factor_name in COVARIANCE_FACTOR_NAMES: save_path = covariance_matrices_save_path( @@ -83,7 +127,18 @@ def covariance_matrices_exist( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> bool: - """Checks if covariance matrices exist at the specified directory.""" + """Checks if covariance matrices exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + bool: + `True` if all covariance matrices exist, `False` otherwise. + """ for factor_name in COVARIANCE_FACTOR_NAMES: save_path = covariance_matrices_save_path( output_dir=output_dir, @@ -121,25 +176,22 @@ def fit_covariance_matrices_with_loader( A list of module names for which covariance matrices will be computed. If not specified, covariance matrices will be computed for all tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: Tuple[torch.Tensor, FACTOR_TYPE]: - A tuple containing the number of data points processed and computed covariance matrices. - The covariance matrices are organized in nested dictionaries, where the first key is the name of the - covariance matrix (e.g., activation covariance and pseudo-gradient covariance) and the second key is - the module name. + - Number of data points processed. 
+ - Computed covariance matrices (nested dict: factor_name -> module_name -> tensor). """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - tracked_module_names=tracked_module_names, - mode=ModuleMode.COVARIANCE, - keep_factors=False, - ) + update_factor_args(model=model, factor_args=factor_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + tracked_module_names=tracked_module_names, + mode=ModuleMode.COVARIANCE, + release_memory=True, + ) total_steps = 0 num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) @@ -157,13 +209,13 @@ def fit_covariance_matrices_with_loader( ) as pbar: for index, batch in enumerate(loader): batch = send_to_device(batch, device=state.device) - with torch.no_grad(): - attention_mask = task.get_attention_mask(batch=batch) - if attention_mask is not None: - set_attention_mask(model=model, attention_mask=attention_mask) - model.zero_grad(set_to_none=True) + attention_mask = task.get_attention_mask(batch=batch) + if attention_mask is not None: + set_attention_mask(model=model, attention_mask=attention_mask) + with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=factor_args.amp_dtype): loss = task.compute_train_loss( batch=batch, @@ -174,32 +226,40 @@ def fit_covariance_matrices_with_loader( if ( state.use_distributed - and total_steps % factor_args.distributed_sync_steps == 0 + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 and index not in [len(loader) - 1, len(loader) - 2] ): - # Periodically synchronizes all processes to avoid timeout at the final synchronization. state.wait_for_everyone() num_data_processed.add_(find_batch_size(data=batch)) + del loss total_steps += 1 pbar.update(1) - with torch.no_grad(): - if state.use_distributed: - # Aggregates covariance matrices across multiple devices or nodes. - synchronize_covariance_matrices(model=model) - num_data_processed = num_data_processed.to(device=state.device) - dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + num_data_processed = num_data_processed.to(device=state.device) + dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + num_data_processed = num_data_processed.cpu() - saved_factors: FACTOR_TYPE = {} + saved_factors: FACTOR_TYPE = {} + if state.is_main_process: for factor_name in COVARIANCE_FACTOR_NAMES: - saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name, clone=True) - state.wait_for_everyone() - - # Clean up the memory. 
- model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) + factor = load_factors( + model=model, + factor_name=factor_name, + tracked_module_names=tracked_module_names, + cpu=True, + ) + if len(factor) == 0: + raise ValueError(f"Factor `{factor_name}` has not been computed.") + saved_factors[factor_name] = factor + + model.zero_grad(set_to_none=True) + set_attention_mask(model=model, attention_mask=None) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return num_data_processed, saved_factors diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index b9db412..ff027ea 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -14,13 +14,13 @@ from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( - finalize_lambda_matrices, + finalize_iteration, get_tracked_module_names, load_factors, set_factors, set_gradient_scale, set_mode, - synchronize_lambda_matrices, + synchronize_modules, update_factor_args, ) from kronfluence.task import Task @@ -28,6 +28,7 @@ ACTIVATION_COVARIANCE_MATRIX_NAME, ACTIVATION_EIGENVALUES_NAME, ACTIVATION_EIGENVECTORS_NAME, + DISTRIBUTED_SYNC_INTERVAL, EIGENDECOMPOSITION_FACTOR_NAMES, FACTOR_TYPE, GRADIENT_COVARIANCE_MATRIX_NAME, @@ -46,13 +47,41 @@ def eigendecomposition_save_path( output_dir: Path, factor_name: str, ) -> Path: - """Generates the path for saving/loading Eigendecomposition results.""" + """Generates the path for saving or loading eigendecomposition results. + + Args: + output_dir (Path): + Directory to save or load eigenvectors and eigenvalues. + factor_name (str): + Name of the factor (must be in `EIGENDECOMPOSITION_FACTOR_NAMES`). + + Returns: + Path: + The full path for the eigendecomposition file. + + Raises: + AssertionError: + If `factor_name` is not in `EIGENDECOMPOSITION_FACTOR_NAMES`. + """ assert factor_name in EIGENDECOMPOSITION_FACTOR_NAMES return output_dir / f"{factor_name}.safetensors" def save_eigendecomposition(output_dir: Path, factors: FACTOR_TYPE, metadata: Optional[Dict[str, str]] = None) -> None: - """Saves Eigendecomposition results to disk.""" + """Saves eigendecomposition results to disk. + + Args: + output_dir (Path): + Directory to save the eigenvectors and eigenvalues. + factors (FACTOR_TYPE): + Dictionary of factors to save. + metadata (Dict[str, str], optional): + Additional metadata to save with the factors. + + Raises: + AssertionError: + If factors keys don't match `EIGENDECOMPOSITION_FACTOR_NAMES`. + """ assert set(factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) for factor_name in factors: save_path = eigendecomposition_save_path( @@ -65,7 +94,16 @@ def save_eigendecomposition(output_dir: Path, factors: FACTOR_TYPE, metadata: Op def load_eigendecomposition( output_dir: Path, ) -> FACTOR_TYPE: - """Loads Eigendecomposition results from disk.""" + """Loads eigendecomposition results from disk. + + Args: + output_dir (Path): + Directory to load the results from. + + Returns: + FACTOR_TYPE: + Dictionary of loaded eigendecomposition results. 
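+
+        Example (illustrative; assumes eigendecomposition results were previously
+        saved to `output_dir` with `save_eigendecomposition`, and that the model
+        contains a tracked module named `model.fc1`):
+
+            eigen_factors = load_eigendecomposition(output_dir=Path("influence_results"))
+            eigenvalues = eigen_factors[ACTIVATION_EIGENVALUES_NAME]["model.fc1"]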
+ """ eigen_factors = {} for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: save_path = eigendecomposition_save_path( @@ -79,7 +117,16 @@ def load_eigendecomposition( def eigendecomposition_exist( output_dir: Path, ) -> bool: - """Checks if Eigendecomposition results exist at the specified path.""" + """Checks if eigendecomposition results exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for results. + + Returns: + bool: + `True` if all eigendecomposition results exist, `False` otherwise. + """ for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: save_path = eigendecomposition_save_path( output_dir=output_dir, @@ -98,26 +145,24 @@ def perform_eigendecomposition( factor_args: FactorArguments, disable_tqdm: bool = False, ) -> FACTOR_TYPE: - """Performs Eigendecomposition on activation and pseudo-gradient covariance matrices. + """Performs eigendecomposition on activation and pseudo-gradient covariance matrices. Args: covariance_factors (FACTOR_TYPE): - The model used to compute covariance matrices. + Computed covariance factors. model (nn.Module): - The model which contains modules which Eigendecomposition will be performed. + The model used to compute covariance matrices. state (State): The current process's information (e.g., device being used). factor_args (FactorArguments): - Arguments for computing Eigendecomposition. + Arguments for performing eigendecomposition. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: FACTOR_TYPE: - The Eigendecomposition results. These results are organized in nested dictionaries, where the first key - is the name of the factor (e.g.,activation eigenvector), and the second key is the module name. - disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + The results are organized in nested dictionaries, where the first key is the name of the factor + (e.g., activation eigenvector), and the second key is the module name. """ eigen_factors: FACTOR_TYPE = {} for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: @@ -150,7 +195,7 @@ def perform_eigendecomposition( device=state.device, dtype=factor_args.eigendecomposition_dtype, ) - # Normalizes covariance matrices. + # Normalize covariance matrices. covariance_matrix.div_(covariance_factors[num_processed_name][module_name].to(device=state.device)) # In cases where covariance matrices are not exactly symmetric due to numerical issues. covariance_matrix = covariance_matrix + covariance_matrix.t() @@ -165,9 +210,14 @@ def perform_eigendecomposition( eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix) else: raise - eigen_factors[eigenvalues_name][module_name] = eigenvalues.to(dtype=original_dtype).contiguous().cpu() - eigen_factors[eigenvectors_name][module_name] = eigenvectors.to(dtype=original_dtype).contiguous().cpu() - del covariance_matrix, eigenvalues, eigenvectors + del covariance_matrix + eigen_factors[eigenvalues_name][module_name] = eigenvalues.contiguous().to( + dtype=original_dtype, device="cpu" + ) + eigen_factors[eigenvectors_name][module_name] = eigenvectors.contiguous().to( + dtype=original_dtype, device="cpu" + ) + del eigenvalues, eigenvectors pbar.update(1) @@ -179,7 +229,24 @@ def lambda_matrices_save_path( factor_name: str, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading Lambda matrices.""" + """Generates the path for saving or loading Lambda matrices. 
+ + Args: + output_dir (Path): + Directory to save or load the matrices. + factor_name (str): + Name of the factor (must be in `LAMBDA_FACTOR_NAMES`). + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the Lambda matrix file. + + Raises: + AssertionError: + If `factor_name` is not in `LAMBDA_FACTOR_NAMES`. + """ assert factor_name in LAMBDA_FACTOR_NAMES if partition is not None: data_partition, module_partition = partition @@ -195,7 +262,22 @@ def save_lambda_matrices( partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves Lambda matrices to disk.""" + """Saves Lambda matrices to disk. + + Args: + output_dir (Path): + Directory to save the matrices. + factors (FACTOR_TYPE): + Dictionary of factors to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the factors. + + Raises: + AssertionError: + If factors keys don't match `LAMBDA_FACTOR_NAMES`. + """ assert set(factors.keys()) == set(LAMBDA_FACTOR_NAMES) for factor_name in factors: save_path = lambda_matrices_save_path( @@ -210,7 +292,18 @@ def load_lambda_matrices( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> FACTOR_TYPE: - """Loads Lambda matrices from disk.""" + """Loads Lambda matrices from disk. + + Args: + output_dir (Path): + Directory to load the matrices from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + FACTOR_TYPE: + Dictionary of loaded Lambda factors. + """ lambda_factors = {} for factor_name in LAMBDA_FACTOR_NAMES: save_path = lambda_matrices_save_path( @@ -226,7 +319,18 @@ def lambda_matrices_exist( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> bool: - """Checks if Lambda matrices exist at the specified path.""" + """Checks if Lambda matrices exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + bool: + `True` if all Lambda matrices exist, `False` otherwise. + """ for factor_name in LAMBDA_FACTOR_NAMES: save_path = lambda_matrices_save_path( output_dir=output_dir, @@ -248,7 +352,7 @@ def fit_lambda_matrices_with_loader( tracked_module_names: Optional[List[str]] = None, disable_tqdm: bool = False, ) -> Tuple[torch.Tensor, FACTOR_TYPE]: - """Computes Lambda (corrected eigenvalues) matrices for a given model and task. + """Computes Lambda matrices for a given model and task. Args: model (nn.Module): @@ -262,32 +366,30 @@ def fit_lambda_matrices_with_loader( factor_args (FactorArguments): Arguments for computing Lambda matrices. eigen_factors (FACTOR_TYPE, optional): - The eigendecomposition results to use for computing Lambda matrices. + Computed eigendecomposition results. tracked_module_names (List[str], optional): A list of module names for which Lambda matrices will be computed. If not specified, Lambda matrices will be computed for all tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: Tuple[torch.Tensor, FACTOR_TYPE]: - A tuple containing the number of data points processed and computed Lambda matrices. - The Lambda matrices are organized in nested dictionaries, where the first key is the name of - the computed variable and the second key is the module name. 
+ - Number of data points processed. + - Computed Lambda matrices (nested dict: factor_name -> module_name -> tensor). """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - tracked_module_names=tracked_module_names, - mode=ModuleMode.LAMBDA, - keep_factors=False, - ) - if eigen_factors is not None: - for name in eigen_factors: - set_factors(model=model, factor_name=name, factors=eigen_factors[name]) + update_factor_args(model=model, factor_args=factor_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + tracked_module_names=tracked_module_names, + mode=ModuleMode.LAMBDA, + release_memory=True, + ) + if eigen_factors is not None: + for name in eigen_factors: + set_factors(model=model, factor_name=name, factors=eigen_factors[name], clone=True) total_steps = 0 num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) @@ -306,8 +408,8 @@ def fit_lambda_matrices_with_loader( for index, batch in enumerate(loader): batch = send_to_device(tensor=batch, device=state.device) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=factor_args.amp_dtype): loss = task.compute_train_loss( batch=batch, @@ -316,40 +418,44 @@ def fit_lambda_matrices_with_loader( ) scaler.scale(loss).backward() - if factor_args.shared_parameters_exist: - # If shared parameter exists, Lambda matrices are computed and updated only after all - # per-sample-gradients are aggregated. - finalize_lambda_matrices(model=model) + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) if ( state.use_distributed - and total_steps % factor_args.distributed_sync_steps == 0 + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 and index not in [len(loader) - 1, len(loader) - 2] ): - # Periodically synchronizes all processes to avoid timeout at the final synchronization. state.wait_for_everyone() num_data_processed.add_(find_batch_size(data=batch)) + del loss total_steps += 1 pbar.update(1) - with torch.no_grad(): - if state.use_distributed: - # Aggregates Lambda matrices across multiple devices or nodes. - synchronize_lambda_matrices(model=model) - num_data_processed = num_data_processed.to(device=state.device) - dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) - - saved_factors: FACTOR_TYPE = {} - if state.is_main_process: - for factor_name in LAMBDA_FACTOR_NAMES: - saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name, clone=True) - state.wait_for_everyone() - - # Clean up the memory. 
- model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + num_data_processed = num_data_processed.to(device=state.device) + dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + num_data_processed = num_data_processed.cpu() + + saved_factors: FACTOR_TYPE = {} + if state.is_main_process: + for factor_name in LAMBDA_FACTOR_NAMES: + factor = load_factors( + model=model, + factor_name=factor_name, + tracked_module_names=tracked_module_names, + cpu=True, + ) + if len(factor) == 0: + raise ValueError(f"Factor `{factor_name}` has not been computed.") + saved_factors[factor_name] = factor + + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return num_data_processed, saved_factors diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 21e7e1d..f7b606a 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -1,11 +1,11 @@ -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F from einconv.utils import get_conv_paddings from einops import rearrange, reduce -from opt_einsum import contract -from torch import nn +from opt_einsum import DynamicProgramming, contract_path +from torch import _VF, nn from torch.nn.modules.utils import _pair from kronfluence.module.tracked_module import TrackedModule @@ -65,11 +65,45 @@ def extract_patches( class TrackedConv2d(TrackedModule, module_type=nn.Conv2d): - """A tracking wrapper for `nn.Conv2D` modules.""" + """A wrapper for `nn.Conv2d` modules.""" - def _get_flattened_activation( - self, input_activation: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + @property + def in_channels(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.in_channels + + @property + def out_channels(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.out_channels + + @property + def kernel_size(self) -> Tuple[int, int]: # pylint: disable=missing-function-docstring + return self.original_module.kernel_size + + @property + def padding(self) -> Tuple[int, int]: # pylint: disable=missing-function-docstring + return self.original_module.padding + + @property + def dilation(self) -> Tuple[int, int]: # pylint: disable=missing-function-docstring + return self.original_module.dilation + + @property + def groups(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.groups + + @property + def padding_mode(self) -> str: # pylint: disable=missing-function-docstring + return self.original_module.padding_mode + + @property + def weight(self) -> torch.Tensor: # pylint: disable=missing-function-docstring + return self.original_module.weight + + @property + def bias(self) -> Optional[torch.Tensor]: # pylint: disable=missing-function-docstring + return self.original_module.bias + + def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: input_activation = extract_patches( inputs=input_activation, kernel_size=self.original_module.kernel_size, @@ -82,7 +116,6 @@ def _get_flattened_activation( 
tensor=input_activation, pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) - if self.original_module.bias is not None: input_activation = torch.cat( [ @@ -94,16 +127,11 @@ def _get_flattened_activation( count = input_activation.size(0) return input_activation, count - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: output_gradient = rearrange(output_gradient, "b c o1 o2 -> (b o1 o2) c") return output_gradient, output_gradient.size(0) - @torch.no_grad() - def _compute_per_sample_gradient( - self, - input_activation: torch.Tensor, - output_gradient: torch.Tensor, - ) -> torch.Tensor: + def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Tensor: input_activation = extract_patches( inputs=input_activation, kernel_size=self.original_module.kernel_size, @@ -116,7 +144,6 @@ def _compute_per_sample_gradient( tensor=input_activation, pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) - if self.original_module.bias is not None: input_activation = torch.cat( [ @@ -125,6 +152,76 @@ def _compute_per_sample_gradient( ], dim=-1, ) + return input_activation + + def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) + output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") + summed_gradient = torch.einsum("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0) + return summed_gradient + + def compute_per_sample_gradient( + self, + input_activation: torch.Tensor, + output_gradient: torch.Tensor, + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) + output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") + per_sample_gradient = torch.einsum("bci,bco->bio", output_gradient, input_activation) + if self.per_sample_gradient_process_fnc is not None: + per_sample_gradient = self.per_sample_gradient_process_fnc( + module_name=self.name, gradient=per_sample_gradient + ) + return per_sample_gradient + + def compute_pairwise_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) + output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") + if isinstance(preconditioned_gradient, list): + left_mat, right_mat = preconditioned_gradient + expr = "qik,qko,b...i,b...o->qb" + if self.einsum_path is None: + path = contract_path( + expr, + left_mat, + right_mat, + output_gradient, + input_activation, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (left_mat, right_mat, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member + expr = "qio,bti,bto->qb" + if self.einsum_path is None: + path = contract_path( + 
expr, + preconditioned_gradient, + output_gradient, + input_activation, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member + + def compute_self_measurement_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - return contract("abm,abn->amn", output_gradient, input_activation) + expr = "bio,bci,bco->b" + if self.einsum_path is None: + path = contract_path( + expr, + preconditioned_gradient, + output_gradient, + input_activation, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index 6b89aae..2fcafec 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -1,25 +1,39 @@ -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from einops import rearrange -from opt_einsum import contract -from torch import nn +from opt_einsum import DynamicProgramming, contract_path +from torch import _VF, nn from kronfluence.module.tracked_module import TrackedModule class TrackedLinear(TrackedModule, module_type=nn.Linear): - """A tracking wrapper for `nn.Linear` modules.""" + """A wrapper for `nn.Linear` modules.""" - def _get_flattened_activation( - self, input_activation: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + @property + def in_features(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.in_features + + @property + def out_features(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.out_features + + @property + def weight(self) -> torch.Tensor: # pylint: disable=missing-function-docstring + return self.original_module.weight + + @property + def bias(self) -> Optional[torch.Tensor]: # pylint: disable=missing-function-docstring + return self.original_module.bias + + def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: flattened_activation = rearrange(tensor=input_activation, pattern="b ... d_in -> (b ...) d_in") flattened_attention_mask = None - if self._attention_mask is not None and flattened_activation.size(0) == self._attention_mask.numel(): + if self.attention_mask is not None and flattened_activation.size(0) == self.attention_mask.numel(): # If the binary attention mask is provided, zero-out appropriate activations. - flattened_attention_mask = rearrange(tensor=self._attention_mask, pattern="b ... -> (b ...) 1") + flattened_attention_mask = rearrange(tensor=self.attention_mask, pattern="b ... -> (b ...) 
1") flattened_activation.mul_(flattened_attention_mask) if self.original_module.bias is not None: @@ -31,20 +45,94 @@ def _get_flattened_activation( count = flattened_activation.size(0) if flattened_attention_mask is None else flattened_attention_mask.sum() return flattened_activation, count - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: flattened_gradient = rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out") - if self._attention_mask is not None and flattened_gradient.size(0) == self._attention_mask.numel(): - count = self._attention_mask.sum() + if self.attention_mask is not None and flattened_gradient.size(0) == self.attention_mask.numel(): + count = self.attention_mask.sum() else: count = flattened_gradient.size(0) return flattened_gradient, count - @torch.no_grad() - def _compute_per_sample_gradient( - self, input_activation: torch.Tensor, output_gradient: torch.Tensor - ) -> torch.Tensor: + def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Tensor: if self.original_module.bias is not None: shape = list(input_activation.size()[:-1]) + [1] append_term = input_activation.new_ones(shape, requires_grad=False) input_activation = torch.cat([input_activation, append_term], dim=-1) - return contract("b...i,b...o->bio", output_gradient, input_activation) + return input_activation + + def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + summed_gradient = torch.einsum("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(dim=0) + return summed_gradient + + def compute_per_sample_gradient( + self, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + per_sample_gradient = torch.einsum("b...i,b...o->bio", output_gradient, input_activation) + if self.per_sample_gradient_process_fnc is not None: + per_sample_gradient = self.per_sample_gradient_process_fnc( + module_name=self.name, gradient=per_sample_gradient + ) + return per_sample_gradient + + def compute_pairwise_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + if isinstance(preconditioned_gradient, list): + left_mat, right_mat = preconditioned_gradient + if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: + expr = "qik,qko,bti,bto->qbt" + else: + expr = "qik,qko,b...i,b...o->qb" + if self.einsum_path is None: + path = contract_path( + expr, + left_mat, + right_mat, + output_gradient, + input_activation, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (left_mat, right_mat, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member + if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: + expr = "qio,bti,bto->qbt" + if self.einsum_path is None: + path = contract_path( + expr, + preconditioned_gradient, + output_gradient, + input_activation, + 
optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member + expr = "qio,b...i,b...o->qb" + if self.einsum_path is None: + path = contract_path( + expr, + preconditioned_gradient, + output_gradient, + input_activation, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member + + def compute_self_measurement_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + expr = "bio,b...i,b...o->b" + if self.einsum_path is None: + path = contract_path( + expr, + preconditioned_gradient, + output_gradient, + input_activation, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.einsum_path = [item for pair in path for item in pair] + return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path) # pylint: disable=no-member diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index e6161b1..22b90c7 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -1,38 +1,40 @@ from abc import abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch -import torch.distributed as dist from accelerate.utils.dataclasses import BaseEnum -from opt_einsum import contract from torch import nn -from torch.utils.hooks import RemovableHandle from kronfluence.arguments import FactorArguments, ScoreArguments from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.module.tracker.factor import CovarianceTracker, LambdaTracker +from kronfluence.module.tracker.gradient import GradientTracker +from kronfluence.module.tracker.pairwise_score import PairwiseScoreTracker +from kronfluence.module.tracker.precondition import PreconditionTracker +from kronfluence.module.tracker.self_score import ( + SelfScoreTracker, + SelfScoreWithMeasurementTracker, +) from kronfluence.utils.constants import ( ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, - ACTIVATION_COVARIANCE_MATRIX_NAME, - ACTIVATION_EIGENVECTORS_NAME, + AGGREGATED_GRADIENT_NAME, COVARIANCE_FACTOR_NAMES, EIGENDECOMPOSITION_FACTOR_NAMES, - GRADIENT_COVARIANCE_MATRIX_NAME, - GRADIENT_EIGENVECTORS_NAME, LAMBDA_FACTOR_NAMES, - LAMBDA_MATRIX_NAME, - NUM_ACTIVATION_COVARIANCE_PROCESSED, - NUM_GRADIENT_COVARIANCE_PROCESSED, - NUM_LAMBDA_PROCESSED, PAIRWISE_SCORE_MATRIX_NAME, PRECONDITIONED_GRADIENT_NAME, + PRECONDITIONED_GRADIENT_TYPE, SELF_SCORE_VECTOR_NAME, ) -from kronfluence.utils.exceptions import FactorsNotFoundError class ModuleMode(str, BaseEnum): - """Enum to represent a module's mode, indicating which factors and scores need to be computed - during forward and backward passes.""" + """Enum representing a module's mode for factor and score computation. + + This enum indicates which factors and scores need to be computed during + forward and backward passes. 
+ """ DEFAULT = "default" COVARIANCE = "covariance" @@ -41,15 +43,27 @@ class ModuleMode(str, BaseEnum): PAIRWISE_SCORE = "pairwise_score" SELF_SCORE = "self_score" SELF_MEASUREMENT_SCORE = "self_measurement_score" + GRADIENT_AGGREGATION = "gradient_aggregation" class TrackedModule(nn.Module): - """A wrapper class for PyTorch modules to compute influence factors and scores.""" + """A wrapper class for PyTorch modules to compute influence factors and scores. + + This class extends `nn.Module` to add functionality for tracking and computing + various influence-related metrics. + """ SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {} def __init_subclass__(cls, module_type: Type[nn.Module] = None, **kwargs: Any) -> None: - """Automatically registers subclasses as supported modules.""" + """Automatically registers subclasses as supported modules. + + Args: + module_type (Type[nn.Module], optional): + The type of module this subclass supports. + **kwargs: + Additional keyword arguments. + """ super().__init_subclass__(**kwargs) if module_type is not None: cls.SUPPORTED_MODULES[module_type] = cls @@ -60,8 +74,9 @@ def __init__( original_module: nn.Module, factor_args: Optional[FactorArguments] = None, score_args: Optional[ScoreArguments] = None, + per_sample_gradient_process_fnc: Optional[Callable] = None, ) -> None: - """Initializes an instance of the TrackedModule class. + """Initializes an instance of the `TrackedModule` class. Args: name (str): @@ -69,128 +84,241 @@ def __init__( original_module (nn.Module): The original module to be wrapped. factor_args (FactorArguments, optional): - Arguments for computing influence factors. + Arguments for computing factors. score_args (ScoreArguments, optional): Arguments for computing influence scores. + per_sample_gradient_process_fnc (Callable, optional): + Optional function to post-process per-sample gradients. """ super().__init__() self.name = name self.original_module = original_module - # A way to avoid Autograd computing the gradient with respect to the model parameters. self._constant: torch.Tensor = nn.Parameter( torch.zeros( 1, + dtype=self.original_module.weight.dtype, requires_grad=True, - dtype=torch.float16, ) ) + self.current_mode = ModuleMode.DEFAULT self.factor_args = FactorArguments() if factor_args is None else factor_args self.score_args = ScoreArguments() if score_args is None else score_args - - self._cached_activations: Optional[Union[List[torch.Tensor]], torch.Tensor] = None - self._cached_per_sample_gradient: Optional[torch.Tensor] = None - self._attention_mask: Optional[torch.Tensor] = None - self._gradient_scale: float = 1.0 - self._registered_hooks: List[RemovableHandle] = [] - self._storage: Dict[str, Optional[Union[torch.Tensor, List[torch.Tensor]]]] = {} - self._storage_at_device: bool = False - - # Storage for activation and pseudo-gradient covariance matrices. 
# + self.per_sample_gradient_process_fnc = per_sample_gradient_process_fnc + + self._trackers = { + ModuleMode.DEFAULT: BaseTracker(self), + ModuleMode.COVARIANCE: CovarianceTracker(self), + ModuleMode.LAMBDA: LambdaTracker(self), + ModuleMode.GRADIENT_AGGREGATION: GradientTracker(self), + ModuleMode.PRECONDITION_GRADIENT: PreconditionTracker(self), + ModuleMode.PAIRWISE_SCORE: PairwiseScoreTracker(self), + ModuleMode.SELF_SCORE: SelfScoreTracker(self), + ModuleMode.SELF_MEASUREMENT_SCORE: SelfScoreWithMeasurementTracker(self), + } + + self.attention_mask: Optional[torch.Tensor] = None + self.gradient_scale: float = 1.0 + self.storage: Dict[str, Optional[Union[torch.Tensor, PRECONDITIONED_GRADIENT_TYPE]]] = {} + self.einsum_path: Optional[List[int]] = None + self._initialize_storage() + + def _initialize_storage(self) -> None: + """Initializes storage for various factors and scores.""" + + # Storage for activation and pseudo-gradient covariance matrices # for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - self._storage[covariance_factor_name]: Optional[torch.Tensor] = None + self.storage[covariance_factor_name]: Optional[torch.Tensor] = None - # Storage for eigenvectors and eigenvalues. # + # Storage for eigenvectors and eigenvalues # for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: - self._storage[eigen_factor_name]: Optional[torch.Tensor] = None + self.storage[eigen_factor_name]: Optional[torch.Tensor] = None - # Storage for lambda matrices. # + # Storage for lambda matrices # for lambda_factor_name in LAMBDA_FACTOR_NAMES: - self._storage[lambda_factor_name]: Optional[torch.Tensor] = None + self.storage[lambda_factor_name]: Optional[torch.Tensor] = None + + # Storage for preconditioned gradients and influence scores # + self.storage[AGGREGATED_GRADIENT_NAME]: Optional[torch.Tensor] = None + self.storage[PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None + self.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None + self.storage[PAIRWISE_SCORE_MATRIX_NAME]: Optional[torch.Tensor] = None + self.storage[SELF_SCORE_VECTOR_NAME]: Optional[torch.Tensor] = None + + def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """Performs a forward pass of the tracked module. - # Storage for preconditioned query gradients and influence scores. # - self._storage[PRECONDITIONED_GRADIENT_NAME]: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME]: Optional[Union[torch.Tensor, List[torch.Tensor]]] = ( - None + This method should have identical behavior to that of the original module. + + Args: + inputs (torch.Tensor): + Input tensor to the module. + *args: + Variable length argument list. + **kwargs: + Arbitrary keyword arguments. + + Returns: + torch.Tensor: + The output of the forward pass. + """ + outputs = self.original_module(inputs, *args, **kwargs) + if outputs.requires_grad: + return outputs + return outputs + self._constant + + def prepare_storage(self, device: torch.device) -> None: + """Prepares storage for computing influence scores. + + This method performs necessary operations on storage before computing influence scores. + + Args: + device (torch.device): + The device to prepare the storage for. 
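+
+        Example (illustrative; called once per module before score computation):
+
+            module.prepare_storage(device=torch.device("cuda"))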
+ """ + FactorConfig.CONFIGS[self.factor_args.strategy].prepare( + storage=self.storage, + score_args=self.score_args, + device=device, ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME]: Optional[torch.Tensor] = None - self._storage[SELF_SCORE_VECTOR_NAME]: Optional[torch.Tensor] = None def update_factor_args(self, factor_args: FactorArguments) -> None: - """Updates the factor arguments.""" + """Updates the factor arguments. + + Args: + factor_args (FactorArguments): + New factor arguments to set. + """ self.factor_args = factor_args def update_score_args(self, score_args: ScoreArguments) -> None: - """Updates the score arguments.""" + """Updates the score arguments. + + Args: + score_args (ScoreArguments): + New score arguments to set. + """ self.score_args = score_args def get_factor(self, factor_name: str) -> Optional[torch.Tensor]: - """Returns the factor with the given name.""" - if factor_name not in self._storage: + """Retrieves a factor by name from storage. + + Args: + factor_name (str): + The name of the factor to retrieve. + + Returns: + Optional[torch.Tensor]: + The requested factor, or `None` if not found. + """ + if factor_name not in self.storage or self.storage[factor_name] is None: return None - return self._storage[factor_name] + return self.storage[factor_name] + + def release_factor(self, factor_name: str) -> None: + """Releases a factor from memory. + + Args: + factor_name (str): + The name of the factor to release. + """ + if factor_name not in self.storage or self.storage[factor_name] is None: + return None + del self.storage[factor_name] + self.storage[factor_name] = None def set_factor(self, factor_name: str, factor: Any) -> None: - """Sets the factor with the given name.""" - if factor_name in self._storage: - self._storage[factor_name] = factor - - def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any: - """A forward pass of the tracked module. This should have identical behavior to that of the original module.""" - return self.original_module(inputs + self._constant, *args, **kwargs) - - def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None: - """Sets the module mode of the current `TrackedModule` instance.""" - self.set_attention_mask(attention_mask=None) - self._remove_registered_hooks() - - if not keep_factors: - self._release_covariance_matrices() - self._release_eigendecomposition_results() - self._release_lambda_matrix() - self.release_preconditioned_gradient() - self._storage_at_device = False - self.release_scores() - - if mode == ModuleMode.DEFAULT: - pass - elif mode == ModuleMode.COVARIANCE: - self._register_covariance_hooks() - elif mode == ModuleMode.LAMBDA: - self._register_lambda_hooks() - elif mode == ModuleMode.PRECONDITION_GRADIENT: - self._register_precondition_gradient_hooks() - elif mode == ModuleMode.PAIRWISE_SCORE: - self._register_pairwise_score_hooks() - elif mode == ModuleMode.SELF_SCORE: - self._register_self_score_hooks() - elif mode == ModuleMode.SELF_MEASUREMENT_SCORE: - self._register_self_measurement_score_hooks() - else: - raise RuntimeError(f"Unknown module mode {mode}.") - - def _remove_registered_hooks(self) -> None: - """Removes all registered hooks within the module.""" - while self._registered_hooks: - handle = self._registered_hooks.pop() - handle.remove() - self._registered_hooks = [] + """Sets a factor in storage. + + Args: + factor_name (str): + The name of the factor to set. + factor (Any): + The factor value to store. 
+ """ + if factor_name in self.storage: + self.storage[factor_name] = factor + + def set_mode(self, mode: ModuleMode, release_memory: bool = False) -> None: + """Sets the operating mode of the `TrackedModule`. + + This method changes the current mode and manages associated trackers and memory. + + Args: + mode (ModuleMode): + The new mode to set. + release_memory (bool): + Whether to release memory for all trackers. + """ + self._trackers[self.current_mode].release_hooks() + self.einsum_path = None + self.current_mode = mode + + if release_memory: + for _, tracker in self._trackers.items(): + tracker.release_memory() + + self._trackers[self.current_mode].register_hooks() def set_attention_mask(self, attention_mask: Optional[torch.Tensor] = None) -> None: - """Sets the attention mask for activation covariance computations.""" - self._attention_mask = attention_mask + """Sets the attention mask for activation covariance computations. + + Args: + attention_mask (torch.Tensor, optional): + The attention mask to set. + """ + self.attention_mask = attention_mask def set_gradient_scale(self, scale: float = 1.0) -> None: - """Sets the scale of the gradient obtained from `GradScaler`.""" - self._gradient_scale = scale + """Sets the scale of the gradient obtained from `GradScaler`. + + Args: + scale (float): + The scale factor to set. + """ + self.gradient_scale = scale + + def finalize_iteration(self) -> None: + """Finalizes statistics for the current iteration.""" + self._trackers[self.current_mode].finalize_iteration() + + def exist(self) -> bool: + """Checks if the desired statistics are available. + + Returns: + bool: + `True` if statistics exist, `False` otherwise. + """ + return self._trackers[self.current_mode].exist() + + def synchronize(self, num_processes: int) -> None: + """Synchronizes statistics across multiple processes. + + Args: + num_processes (int): + The number of processes to synchronize across. + """ + self._trackers[self.current_mode].synchronize(num_processes=num_processes) + + def truncate(self, keep_size: int) -> None: + """Truncates stored statistics to a specified size. + + Args: + keep_size (int): + The number of dimension to keep. + """ + self._trackers[self.current_mode].truncate(keep_size=keep_size) + + def accumulate_iterations(self) -> None: + """Accumulates (or prepares to accumulate) statistics across multiple iterations.""" + self._trackers[self.current_mode].accumulate_iterations() + + def finalize_all_iterations(self) -> None: + """Finalizes statistics after all iterations.""" + self._trackers[self.current_mode].finalize_all_iterations() - ############################################## - # Methods for computing covariance matrices. # - ############################################## @abstractmethod - def _get_flattened_activation( - self, input_activation: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: """Returns the flattened activation tensor and the number of stacked activations. Args: @@ -202,43 +330,10 @@ def _get_flattened_activation( The flattened activation tensor and the number of stacked activations. The flattened activation is a 2-dimensional matrix with dimension `activation_num x activation_dim`. 
""" - raise NotImplementedError("Subclasses must implement the `_get_flattened_activation` method.") - - def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -> None: - """Computes and updates the activation covariance matrix. - - Args: - input_activation (torch.Tensor): - The input tensor to the module, provided by the PyTorch's forward hook. - """ - input_activation = input_activation.to(dtype=self.factor_args.activation_covariance_dtype) - flattened_activation, count = self._get_flattened_activation(input_activation=input_activation) - - if self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] is None: - dimension = flattened_activation.size(1) - self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] = torch.zeros( - size=(dimension, dimension), - dtype=flattened_activation.dtype, - device=flattened_activation.device, - requires_grad=False, - ) - self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) - - if self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] is None: - device = None - if isinstance(count, torch.Tensor): - # When using attention masks, `count` can be a tensor. - device = count.device - self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros( - size=(1,), - dtype=torch.int64, - device=device, - requires_grad=False, - ) - self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count) + raise NotImplementedError("Subclasses must implement the `get_flattened_activation` method.") @abstractmethod - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: """Returns the flattened output gradient tensor. Args: @@ -251,109 +346,30 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch. The flattened output gradient tensor and the number of stacked gradients. The flattened gradient is a 2-dimensional matrix with dimension `gradient_num x gradient_dim`. """ - raise NotImplementedError("Subclasses must implement the `_get_flattened_gradient` method.") + raise NotImplementedError("Subclasses must implement the `get_flattened_gradient` method.") - def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None: - """Computes and updates the pseudo-gradient covariance matrix. + @abstractmethod + def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: + """Returns the summed gradient tensor. Args: + input_activation (torch.Tensor): + The input tensor to the module, provided by the PyTorch's forward hook. output_gradient (torch.Tensor): - The gradient tensor with respect to the output of the module, provided by the - PyTorch's backward hook. 
- """ - output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) - flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient) - if self._gradient_scale != 1.0: - flattened_gradient = flattened_gradient * self._gradient_scale - - if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None: - dimension = flattened_gradient.size(1) - self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] = torch.zeros( - size=(dimension, dimension), - dtype=flattened_gradient.dtype, - device=flattened_gradient.device, - requires_grad=False, - ) - self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) - - # This is not necessary as `NUM_GRADIENT_COVARIANCE_PROCESSED` should be identical to - # `NUM_ACTIVATION_COVARIANCE_PROCESSED` in most cases. However, they can be different when using - # gradient checkpointing or torch compile. - if self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None: - self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( - size=(1,), - dtype=torch.int64, - device=count.device if isinstance(count, torch.Tensor) else None, - requires_grad=False, - ) - self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count) - - def _register_covariance_hooks(self) -> None: - """Installs forward and backward hooks for computation of the covariance matrices.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - # Computes and updates activation covariance matrix in the forward pass. - self._update_activation_covariance_matrix(inputs[0].detach().clone()) - # Registers backward hook to obtain gradient with respect to the output. - outputs.register_hook(backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - # Computes and updates pseudo-gradient covariance matrix in the backward pass. - self._update_gradient_covariance_matrix(output_gradient.detach()) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - def _release_covariance_matrices(self) -> None: - """Clears the stored activation and pseudo-gradient covariance matrices from memory.""" - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - del self._storage[covariance_factor_name] - self._storage[covariance_factor_name] = None - - def _covariance_matrices_available(self) -> bool: - """Checks if the covariance matrices are currently stored in the storage.""" - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - if self._storage[covariance_factor_name] is None: - return False - return True - - @torch.no_grad() - def synchronize_covariance_matrices(self) -> None: - """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and torch.cuda.is_available() and self._covariance_matrices_available(): - # Note that only the main process holds the aggregated covariance matrix. - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - self._storage[covariance_factor_name] = self._storage[covariance_factor_name].cuda() - dist.reduce( - tensor=self._storage[covariance_factor_name], - op=dist.ReduceOp.SUM, - dst=0, - ) - - ########################################## - # Methods for computing Lambda matrices. 
#
-    ##########################################
-    def _release_eigendecomposition_results(self) -> None:
-        """Clears the stored eigenvectors and eigenvalues from memory."""
-        for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES:
-            del self._storage[eigen_factor_name]
-            self._storage[eigen_factor_name] = None
+                The gradient tensor with respect to the output of the module, provided by PyTorch's backward hook.

-    def _eigendecomposition_results_available(self) -> bool:
-        """Checks if the eigendecomposition results are currently stored in storage."""
-        for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES:
-            if self._storage[eigen_factor_name] is None:
-                return False
-        return True
+        Returns:
+            torch.Tensor:
+                The aggregated gradient tensor.
+        """
+        raise NotImplementedError("Subclasses must implement the `compute_summed_gradient` method.")

     @abstractmethod
-    def _compute_per_sample_gradient(
+    def compute_per_sample_gradient(
         self, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
-        """Returns the flattened per-sample-gradient tensor. For a brief introduction to
-        per-sample-gradients, see https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.
+        """Returns the flattened per-sample gradient tensor. For a brief introduction to
+        per-sample gradients, see https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.

         Args:
             input_activation (torch.Tensor):
@@ -363,658 +379,38 @@
         Returns:
             torch.Tensor:
-            The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix
-            with dimension `batch_size x gradient_dim x activation_dim`.
+            The per-sample gradient tensor.
         """
-        raise NotImplementedError("Subclasses must implement the `_compute_per_sample_gradient` method.")
+        raise NotImplementedError("Subclasses must implement the `compute_per_sample_gradient` method.")

-    def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
-        """Computes and updates the Lambda matrix using the provided per-sample-gradient.
-
-        Args:
-            per_sample_gradient (torch.Tensor):
-                The per-sample-gradient tensor for the given batch.
-        """
-        per_sample_gradient = per_sample_gradient.to(self.factor_args.lambda_dtype)
-        batch_size = per_sample_gradient.size(0)
-        if self._gradient_scale != 1.0:
-            per_sample_gradient.mul_(self._gradient_scale)
-
-        if self._storage[LAMBDA_MATRIX_NAME] is None:
-            # Initializes Lambda matrix if it does not exist.
-            self._storage[LAMBDA_MATRIX_NAME] = torch.zeros(
-                size=(per_sample_gradient.size(1), per_sample_gradient.size(2)),
-                dtype=per_sample_gradient.dtype,
-                device=per_sample_gradient.device,
-                requires_grad=False,
-            )
-
-        if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda:
-            if not self._eigendecomposition_results_available():
-                error_msg = (
-                    f"The strategy {self.factor_args.strategy} requires Eigendecomposition "
-                    f"results to be loaded for Lambda computations. However, Eigendecomposition "
-                    f"results are not found."
-                )
-                raise FactorsNotFoundError(error_msg)
-            # Moves activation and pseudo-gradient eigenvectors to appropriate devices.
- self._storage[ACTIVATION_EIGENVECTORS_NAME] = self._storage[ACTIVATION_EIGENVECTORS_NAME].to( - dtype=self.factor_args.lambda_dtype, - device=per_sample_gradient.device, - ) - self._storage[GRADIENT_EIGENVECTORS_NAME] = self._storage[GRADIENT_EIGENVECTORS_NAME].to( - dtype=self.factor_args.lambda_dtype, - device=per_sample_gradient.device, - ) - - if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda: - if self.factor_args.lambda_iterative_aggregate: - # This batch-wise iterative update can be useful when the GPU memory is limited. - per_sample_gradient = torch.matmul( - per_sample_gradient, - self._storage[ACTIVATION_EIGENVECTORS_NAME], - ) - for i in range(batch_size): - sqrt_lambda = torch.matmul( - self._storage[GRADIENT_EIGENVECTORS_NAME].t(), - per_sample_gradient[i], - ) - self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_()) - else: - per_sample_gradient = torch.matmul( - self._storage[GRADIENT_EIGENVECTORS_NAME].t(), - torch.matmul(per_sample_gradient, self._storage[ACTIVATION_EIGENVECTORS_NAME]), - ) - self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) - else: - # Approximates the eigenbasis as identity. - self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) - - if self._storage[NUM_LAMBDA_PROCESSED] is None: - self._storage[NUM_LAMBDA_PROCESSED] = torch.zeros( - size=(1,), - dtype=torch.int64, - requires_grad=False, - ) - self._storage[NUM_LAMBDA_PROCESSED].add_(batch_size) - - def _register_lambda_hooks(self) -> None: - """Installs forward and backward hooks for computation of the Lambda matrices.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.factor_args.per_sample_gradient_dtype) - if self.factor_args.cached_activation_cpu_offload: - cached_activation = cached_activation.cpu() - - if self.factor_args.shared_parameters_exist: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - # Registers backward hook to obtain gradient with respect to the output. - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - if self._cached_activations is None: - raise RuntimeError( - f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." 
- ) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.factor_args.per_sample_gradient_dtype), - ).to(dtype=self.factor_args.lambda_dtype) - del self._cached_activations - self._cached_activations = None - self._update_lambda_matrix(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.factor_args.per_sample_gradient_dtype), - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = per_sample_gradient - else: - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - def _clear_per_sample_gradient_cache(self) -> None: - """Clears all caches from per-sample-gradient computations.""" - del self._cached_per_sample_gradient - self._cached_per_sample_gradient = None - del self._cached_activations - self._cached_activations = None - - @torch.no_grad() - def finalize_lambda_matrix(self) -> None: - """Computes and updates the Lambda matrix using the cached per-sample-gradient.""" - self._update_lambda_matrix( - per_sample_gradient=self._cached_per_sample_gradient.to(dtype=self.factor_args.lambda_dtype) - ) - self._clear_per_sample_gradient_cache() - - def _release_lambda_matrix(self) -> None: - """Clears the stored Lambda matrix from memory.""" - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - del self._storage[lambda_factor_name] - self._storage[lambda_factor_name] = None - self._clear_per_sample_gradient_cache() - - def _lambda_matrix_available(self) -> bool: - """Checks if the Lamda matrix is currently stored in storage.""" - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - if self._storage[lambda_factor_name] is None: - return False - return True - - @torch.no_grad() - def synchronize_lambda_matrices(self) -> None: - """Aggregates Lambda matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and torch.cuda.is_available() and self._lambda_matrix_available(): - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - self._storage[lambda_factor_name] = self._storage[lambda_factor_name].cuda() - torch.distributed.reduce( - tensor=self._storage[lambda_factor_name], - op=dist.ReduceOp.SUM, - dst=0, - ) - - ################################################## - # Methods for computing preconditioned gradient. # - ################################################## - def _compute_low_rank_preconditioned_gradient( - self, - preconditioned_gradient: torch.Tensor, - ) -> List[torch.Tensor]: - """Performs low-rank approximation of the preconditioned gradient with SVD. + @abstractmethod + def compute_pairwise_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + """Computes pairwise influence scores. Args: preconditioned_gradient (torch.Tensor): - The preconditioned per-sample-gradient matrix to be low-rank approximated. - - Returns: - List[torch.Tensor, torch.Tensor]: - Low-rank matrices that approximate the original preconditioned query gradient. 
- """ - U, S, V = torch.linalg.svd( # pylint: disable=not-callable - preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype), - full_matrices=False, - ) - rank = self.score_args.query_gradient_rank - U_k = U[:, :, :rank] - S_k = S[:, :rank] - # Avoids holding the full memory of the original tensor before indexing. - V_k = V[:, :rank, :].contiguous().clone() - return [ - torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=self.score_args.score_dtype).contiguous().clone(), - V_k.to(dtype=self.score_args.score_dtype), - ] - - def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> None: - """Computes the preconditioned per-sample-gradient. - - Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. - """ - per_sample_gradient = per_sample_gradient.to(dtype=self.score_args.precondition_dtype) - if self._gradient_scale != 1.0: - per_sample_gradient.mul_(self._gradient_scale) - - preconditioned_gradient = FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( - gradient=per_sample_gradient, - storage=self._storage, - damping=self.score_args.damping, - ) - del per_sample_gradient - - if ( - self.score_args.query_gradient_rank is not None - and min(preconditioned_gradient.size()[1:]) > self.score_args.query_gradient_rank - ): - # Applies low-rank approximation to the preconditioned gradient. - preconditioned_gradient = self._compute_low_rank_preconditioned_gradient( - preconditioned_gradient=preconditioned_gradient - ) - self._storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient - else: - self._storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient.to(dtype=self.score_args.score_dtype) - - def _register_precondition_gradient_hooks(self) -> None: - """Installs forward and backward hooks for computation of preconditioned per-sample-gradient.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype) - if self.score_args.cached_activation_cpu_offload: - cached_activation = cached_activation.cpu() - - if self.factor_args.shared_parameters_exist: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - if self._cached_activations is None: - raise RuntimeError( - f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." 
- ) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.precondition_dtype) - del self._cached_activations - self._cached_activations = None - self._compute_preconditioned_gradient(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = per_sample_gradient - else: - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - @torch.no_grad() - def finalize_preconditioned_gradient(self) -> None: - """Computes the aggregated preconditioned gradient using the cached per-sample-gradient.""" - self._compute_preconditioned_gradient( - per_sample_gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype) - ) - self._clear_per_sample_gradient_cache() - - @torch.no_grad() - def accumulate_preconditioned_gradient(self) -> None: - """Accumulates the preconditioned per-sample-gradients computed over different batches.""" - if self._storage[PRECONDITIONED_GRADIENT_NAME] is None: - return - - accumulated_gradient = self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] - gradient = self._storage[PRECONDITIONED_GRADIENT_NAME] - - if self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] is None: - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = gradient - else: - if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = [ - torch.cat((accumulated_gradient[0], gradient[0]), dim=0).contiguous(), - torch.cat((accumulated_gradient[1], gradient[1]), dim=0).contiguous(), - ] - else: - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat( - (accumulated_gradient, gradient), dim=0 - ).contiguous() - del self._storage[PRECONDITIONED_GRADIENT_NAME] - self._storage[PRECONDITIONED_GRADIENT_NAME] = None - - def release_preconditioned_gradient(self) -> None: - """Clears the preconditioned per-sample-gradient from memory.""" - del self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None - del self._storage[PRECONDITIONED_GRADIENT_NAME] - self._storage[PRECONDITIONED_GRADIENT_NAME] = None - self._clear_per_sample_gradient_cache() - - @torch.no_grad() - def truncate_preconditioned_gradient(self, keep_size: int) -> None: - """Truncates and keeps only the first keep_size dimension for the preconditioned gradient.""" - if self._storage[PRECONDITIONED_GRADIENT_NAME] is not None: - if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - assert len(self._storage[PRECONDITIONED_GRADIENT_NAME]) == 2 - self._storage[PRECONDITIONED_GRADIENT_NAME] = [ - self._storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size].clone(), - self._storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size].clone(), - ] - else: - self._storage[PRECONDITIONED_GRADIENT_NAME] = self._storage[PRECONDITIONED_GRADIENT_NAME][ - :keep_size 
- ].clone() - - @torch.no_grad() - def synchronize_preconditioned_gradient(self, num_processes: int) -> None: - """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" - if ( - dist.is_initialized() - and torch.cuda.is_available() - and self._storage[PRECONDITIONED_GRADIENT_NAME] is not None - ): - if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])): - size = self._storage[PRECONDITIONED_GRADIENT_NAME][i].size() - stacked_matrix = torch.empty( - size=(num_processes, size[0], size[1], size[2]), - dtype=self._storage[PRECONDITIONED_GRADIENT_NAME][i].dtype, - device=self._storage[PRECONDITIONED_GRADIENT_NAME][i].device, - ) - torch.distributed.all_gather_into_tensor( - output_tensor=stacked_matrix, - input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(), - ) - self._storage[PRECONDITIONED_GRADIENT_NAME][i] = ( - stacked_matrix.transpose(0, 1) - .reshape(num_processes * size[0], size[1], size[2]) - .contiguous() - .clone() - ) - else: - size = self._storage[PRECONDITIONED_GRADIENT_NAME].size() - stacked_preconditioned_gradient = torch.empty( - size=(num_processes, size[0], size[1], size[2]), - dtype=self._storage[PRECONDITIONED_GRADIENT_NAME].dtype, - device=self._storage[PRECONDITIONED_GRADIENT_NAME].device, - ) - torch.distributed.all_gather_into_tensor( - output_tensor=stacked_preconditioned_gradient, - input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME].contiguous(), - ) - self._storage[PRECONDITIONED_GRADIENT_NAME] = ( - stacked_preconditioned_gradient.transpose(0, 1) - .reshape(num_processes * size[0], size[1], size[2]) - .contiguous() - .clone() - ) - - ########################################### - # Methods for computing influence scores. # - ########################################### - def _compute_pairwise_score(self, per_sample_gradient: torch.Tensor) -> None: - """Computes the pairwise influence scores. - - Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. - """ - per_sample_gradient = per_sample_gradient.to(dtype=self.score_args.score_dtype) - if self._gradient_scale != 1.0: - per_sample_gradient.mul_(self._gradient_scale) - - if isinstance(self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], list): - # The preconditioned gradient is stored as a low-rank approximation. - left_mat, right_mat = self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] - input_dim = right_mat.size(2) - output_dim = left_mat.size(1) - query_batch_size = left_mat.size(0) - train_batch_size = per_sample_gradient.size(0) - rank = self.score_args.query_gradient_rank - if ( - train_batch_size * query_batch_size * rank * min((input_dim, output_dim)) - > query_batch_size * input_dim * output_dim - ): - # If reconstructing the gradient is more memory efficient, reconstructs and computes the score. - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract( - "qki,toi,qok->qt", - right_mat, - per_sample_gradient, - left_mat, - ) - # Otherwise, tries to avoid reconstructing the full per-sample-gradient. 
-            elif output_dim >= input_dim:
-                self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract(
-                    "qki,qtik->qt", right_mat, contract("toi,qok->qtik", per_sample_gradient, left_mat)
-                )
-            else:
-                self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract(
-                    "qtko,qok->qt", contract("qki,toi->qtko", right_mat, per_sample_gradient), left_mat
-                )
-        else:
-            self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract(
-                "qio,tio->qt",
-                self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME],
-                per_sample_gradient,
-            )
-
-    def _register_pairwise_score_hooks(self) -> None:
-        """Installs forward and backward hooks for computation of pairwise influence scores."""
-
-        @torch.no_grad()
-        def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
-            del module
-            cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype)
-            if self.score_args.cached_activation_cpu_offload:
-                cached_activation = cached_activation.cpu()
-
-            if self.factor_args.shared_parameters_exist:
-                if self._cached_activations is None:
-                    self._cached_activations = []
-                self._cached_activations.append(cached_activation)
-            else:
-                self._cached_activations = cached_activation
-
-            outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook)
-
-        @torch.no_grad()
-        def backward_hook(output_gradient: torch.Tensor) -> None:
-            if self._cached_activations is None:
-                raise RuntimeError(
-                    f"The module {self.name} is used several times during a forward pass. "
-                    "Set `shared_parameters_exist=True` to avoid the error."
-                )
-            per_sample_gradient = self._compute_per_sample_gradient(
-                input_activation=self._cached_activations.to(device=output_gradient.device),
-                output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype),
-            ).to(dtype=self.score_args.score_dtype)
-            del self._cached_activations
-            self._cached_activations = None
-            self._compute_pairwise_score(per_sample_gradient=per_sample_gradient)
-
-        @torch.no_grad()
-        def shared_backward_hook(output_gradient: torch.Tensor) -> None:
-            cached_activation = self._cached_activations.pop()
-            per_sample_gradient = self._compute_per_sample_gradient(
-                input_activation=cached_activation.to(device=output_gradient.device),
-                output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype),
-            )
-            if self._cached_per_sample_gradient is None:
-                self._cached_per_sample_gradient = per_sample_gradient
-            else:
-                self._cached_per_sample_gradient.add_(per_sample_gradient)
-
-        self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
-
-    @torch.no_grad()
-    def finalize_pairwise_score(self) -> None:
-        """Computes the pairwise influence scores using the cached per-sample-gradient."""
-        self._compute_pairwise_score(
-            per_sample_gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.score_dtype)
-        )
-        self._clear_per_sample_gradient_cache()
-
-    def _compute_self_score(self, per_sample_gradient: torch.Tensor) -> None:
-        """Computes the self-influence scores.
-
-        Args:
-            per_sample_gradient (torch.Tensor):
-                The per-sample-gradient tensor for the given batch.
+                The preconditioned gradient.
+            input_activation (torch.Tensor):
+                The input tensor to the module, provided by PyTorch's forward hook.
+            output_gradient (torch.Tensor):
+                The gradient tensor with respect to the output of the module, provided by PyTorch's backward hook.
""" - if self._gradient_scale != 1.0: - per_sample_gradient.mul_(self._gradient_scale) - - if not self._storage_at_device: - self._move_storage_to_device( - target_device=per_sample_gradient.device, target_dtype=self.score_args.precondition_dtype - ) - self._storage_at_device = True - preconditioned_gradient = ( - FactorConfig.CONFIGS[self.factor_args.strategy] - .precondition_gradient( - gradient=per_sample_gradient.to(dtype=self.score_args.precondition_dtype), - storage=self._storage, - damping=self.score_args.damping, - ) - .to(dtype=self.score_args.score_dtype) - ) - preconditioned_gradient.mul_(per_sample_gradient.to(dtype=self.score_args.score_dtype)) - self._storage[SELF_SCORE_VECTOR_NAME] = preconditioned_gradient.sum(dim=(1, 2)) - - def _register_self_score_hooks(self) -> None: - """Installs forward and backward hooks for computation of self-influence scores.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype) - if self.score_args.cached_activation_cpu_offload: - cached_activation = cached_activation.cpu() - - if self.factor_args.shared_parameters_exist: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - if self._cached_activations is None: - raise RuntimeError( - f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." - ) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.precondition_dtype) - del self._cached_activations - self._cached_activations = None - self._compute_self_score(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = per_sample_gradient - else: - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) + raise NotImplementedError("Subclasses must implement the `compute_pairwise_score` method.") - def finalize_self_score(self) -> None: - """Computes the self-influence scores using the cached per-sample-gradient.""" - self._compute_self_score(per_sample_gradient=self._cached_per_sample_gradient) - self._clear_per_sample_gradient_cache() - - def _compute_self_measurement_score(self, per_sample_gradient: torch.Tensor) -> None: - """Computes the self-influence scores with measurement. 
+    @abstractmethod
+    def compute_self_measurement_score(
+        self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
+    ) -> torch.Tensor:
+        """Computes self-influence scores with measurement.
 
         Args:
-            per_sample_gradient (torch.Tensor):
-                The per-sample-gradient tensor for the given batch.
+            preconditioned_gradient (torch.Tensor):
+                The preconditioned gradient.
+            input_activation (torch.Tensor):
+                The input tensor to the module, provided by PyTorch's forward hook.
+            output_gradient (torch.Tensor):
+                The gradient tensor with respect to the output of the module, provided by PyTorch's backward hook.
         """
-        per_sample_gradient = per_sample_gradient.to(dtype=self.score_args.score_dtype)
-        if self._gradient_scale != 1.0:
-            per_sample_gradient.mul_(self._gradient_scale)
-        if not self._storage_at_device:
-            self._move_storage_to_device(
-                target_device=per_sample_gradient.device, target_dtype=self.score_args.precondition_dtype
-            )
-            self._storage_at_device = True
-        self._storage[SELF_SCORE_VECTOR_NAME] = per_sample_gradient.mul_(
-            self._storage[PRECONDITIONED_GRADIENT_NAME]
-        ).sum(dim=(1, 2))
-        del self._storage[PRECONDITIONED_GRADIENT_NAME]
-        self._storage[PRECONDITIONED_GRADIENT_NAME] = None
-
-    def _register_self_measurement_score_hooks(self) -> None:
-        """Installs forward and backward hooks for computation of self-influence scores."""
-
-        @torch.no_grad()
-        def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
-            del module
-            cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype)
-            if self.score_args.cached_activation_cpu_offload:
-                cached_activation = cached_activation.cpu()
-
-            if self.factor_args.shared_parameters_exist:
-                if self._cached_activations is None:
-                    self._cached_activations = []
-                self._cached_activations.append(cached_activation)
-            else:
-                self._cached_activations = cached_activation
-
-            outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook)
-
-        @torch.no_grad()
-        def backward_hook(output_gradient: torch.Tensor) -> None:
-            if self._cached_activations is None:
-                raise RuntimeError(
-                    f"The module {self.name} is used several times during a forward pass. "
-                    "Set `shared_parameters_exist=True` to avoid this error."
- ) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.score_dtype) - del self._cached_activations - self._cached_activations = None - self._compute_self_measurement_score(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = per_sample_gradient - else: - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - @torch.no_grad() - def finalize_self_measurement_score(self) -> None: - """Computes the self-influence scores with measurement using the cached per-sample-gradient.""" - self._compute_self_measurement_score( - per_sample_gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.score_dtype) - ) - self._clear_per_sample_gradient_cache() - - def _move_storage_to_device(self, target_device: torch.device, target_dtype: torch.dtype) -> None: - """Moves stored factors into the target device.""" - for name, factor in self._storage.items(): - if factor is not None: - if isinstance(factor, list): - for i in range(len(self._storage[name])): - self._storage[name][i] = factor[i].to( - device=target_device, - dtype=target_dtype, - ) - else: - self._storage[name] = factor.to(device=target_device, dtype=target_dtype) - - def release_scores(self) -> None: - """Clears the influence scores from memory.""" - del self._storage[PAIRWISE_SCORE_MATRIX_NAME] - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = None - del self._storage[SELF_SCORE_VECTOR_NAME] - self._storage[SELF_SCORE_VECTOR_NAME] = None - self._clear_per_sample_gradient_cache() + raise NotImplementedError("Subclasses must implement the `compute_self_measurement_score` method.") diff --git a/kronfluence/module/tracker/__init__.py b/kronfluence/module/tracker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kronfluence/module/tracker/base.py b/kronfluence/module/tracker/base.py new file mode 100644 index 0000000..c6f603f --- /dev/null +++ b/kronfluence/module/tracker/base.py @@ -0,0 +1,110 @@ +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.utils.hooks import RemovableHandle + + +class BaseTracker: + """Base class for tracking module activations, gradients, and scores.""" + + def __init__(self, module: nn.Module) -> None: + """Initializes an instance of the `BaseTracker` class. + + Args: + module (TrackedModule): + The `TrackedModule` that wraps the original module. 
+        """
+        self.module = module
+        self.registered_hooks: List[RemovableHandle] = []
+        self.cached_hooks: List[RemovableHandle] = []
+        self.cached_activations: Optional[Union[List[torch.Tensor], torch.Tensor]] = None
+        self.cached_per_sample_gradient: Optional[torch.Tensor] = None
+
+    def release_hooks(self) -> None:
+        """Removes all registered hooks."""
+        self.clear_all_cache()
+        while self.registered_hooks:
+            handle = self.registered_hooks.pop()
+            handle.remove()
+        self.registered_hooks = []
+
+    def clear_all_cache(self) -> None:
+        """Clears all cached data and removes cached hooks."""
+        del self.cached_activations, self.cached_per_sample_gradient
+        self.cached_activations, self.cached_per_sample_gradient = None, None
+        while self.cached_hooks:
+            handle = self.cached_hooks.pop()
+            handle.remove()
+        self.cached_hooks = []
+
+    def _raise_cache_not_found_exception(self) -> None:
+        """Raises an exception when cached activations are not found."""
+        raise RuntimeError(
+            f"Module '{self.module.name}' has no cached activations. This can occur if:\n"
+            f"1. The module was not used during the forward pass, or\n"
+            f"2. The module was encountered multiple times in the forward pass.\n"
+            f"For case 2, set 'has_shared_parameters=True' to enable parameter sharing."
+        )
+
+    def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
+        """Preprocesses the output gradient.
+
+        Args:
+            output_gradient (torch.Tensor):
+                The original output gradient.
+            target_dtype (torch.dtype):
+                The desired data type for the gradient tensor.
+
+        Returns:
+            torch.Tensor:
+                The preprocessed gradient.
+        """
+        original_dtype = output_gradient.dtype
+        output_gradient = output_gradient.to(dtype=target_dtype)
+        if self.module.gradient_scale != 1.0:
+            if original_dtype != target_dtype:
+                output_gradient.mul_(self.module.gradient_scale)
+            else:
+                output_gradient = output_gradient * self.module.gradient_scale
+        return output_gradient
+
+    def register_hooks(self) -> None:
+        """Registers hooks for the module."""
+
+    def finalize_iteration(self) -> None:
+        """Finalizes statistics for the current iteration."""
+
+    def exist(self) -> bool:
+        """Checks if the desired statistics are available.
+
+        Returns:
+            bool:
+                `True` if statistics exist, `False` otherwise.
+        """
+        return False
+
+    def synchronize(self, num_processes: int) -> None:
+        """Synchronizes statistics across multiple processes.
+
+        Args:
+            num_processes (int):
+                The number of processes to synchronize across.
+        """
+
+    def truncate(self, keep_size: int) -> None:
+        """Truncates stored statistics to a specified size.
+
+        Args:
+            keep_size (int):
+                The number of dimensions to keep.
+ """ + + def accumulate_iterations(self) -> None: + """Accumulates (or prepares to accumulate) statistics across multiple iterations.""" + + def finalize_all_iterations(self) -> None: + """Finalizes statistics after all iterations.""" + + def release_memory(self) -> None: + """Releases any memory held by the tracker.""" diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py new file mode 100644 index 0000000..c649f97 --- /dev/null +++ b/kronfluence/module/tracker/factor.py @@ -0,0 +1,326 @@ +from typing import Tuple, Union + +import torch +import torch.distributed as dist +from torch import nn + +from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACTIVATION_COVARIANCE_MATRIX_NAME, + ACTIVATION_EIGENVECTORS_NAME, + COVARIANCE_FACTOR_NAMES, + EIGENDECOMPOSITION_FACTOR_NAMES, + GRADIENT_COVARIANCE_MATRIX_NAME, + GRADIENT_EIGENVECTORS_NAME, + LAMBDA_FACTOR_NAMES, + LAMBDA_MATRIX_NAME, + NUM_ACTIVATION_COVARIANCE_PROCESSED, + NUM_GRADIENT_COVARIANCE_PROCESSED, + NUM_LAMBDA_PROCESSED, +) +from kronfluence.utils.exceptions import FactorsNotFoundError + + +class CovarianceTracker(BaseTracker): + """Tracks and computes activation and gradient covariance matrices for a given module.""" + + _activation_covariance_initialized: bool = False + _gradient_covariance_initialized: bool = False + + def _update_activation_covariance_matrix( + self, input_activation: torch.Tensor, count: Union[torch.Tensor, int] + ) -> None: + """Computes and updates the activation covariance matrix. + + Args: + input_activation (torch.Tensor): + The flattened input tensor to the module, provided by PyTorch's forward hook. + count (int): + The number of activations. + """ + if not self._activation_covariance_initialized: + self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + device=count.device if isinstance(count, torch.Tensor) else None, + requires_grad=False, + ) + dimension = input_activation.size(1) + self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME] = torch.zeros( + size=(dimension, dimension), + dtype=input_activation.dtype, + device=input_activation.device, + requires_grad=False, + ) + self._activation_covariance_initialized = True + self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count) + self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(input_activation.t(), input_activation) + + def _update_gradient_covariance_matrix( + self, output_gradient: torch.Tensor, count: Union[torch.Tensor, int] + ) -> None: + """Computes and updates the pseudo-gradient covariance matrix. + + Args: + output_gradient (torch.Tensor): + The flattened gradient tensor with respect to the output of the module, provided + by PyTorch's backward hook. + count (int): + The number of gradients. + """ + if not self._gradient_covariance_initialized: + # In most cases, `NUM_GRADIENT_COVARIANCE_PROCESSED` and `NUM_ACTIVATION_COVARIANCE_PROCESSED` are + # identical. However, they may differ when using gradient checkpointing or `torch.compile()`. 
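+            # (Illustration: with gradient checkpointing, the forward hook fires again during the
+            # recomputation pass, so the activation count can exceed the gradient count.) As on the
+            # activation side, the factor accumulated below is a Gram matrix: the `addmm_` call is
+            # equivalent to `cov += output_gradient.t() @ output_gradient`.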
+ self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + device=count.device if isinstance(count, torch.Tensor) else None, + requires_grad=False, + ) + dimension = output_gradient.size(1) + self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME] = torch.zeros( + size=(dimension, dimension), + dtype=output_gradient.dtype, + device=output_gradient.device, + requires_grad=False, + ) + self._gradient_covariance_initialized = True + self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count) + self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient) + + def register_hooks(self) -> None: + """Sets up hooks to compute activation and gradient covariance matrices.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + input_activation = ( + inputs[0] + .detach() + .to( + dtype=self.module.factor_args.activation_covariance_dtype, + copy=self.module.attention_mask is not None, + ) + ) + # Computes and updates activation covariance during forward pass. + input_activation, count = self.module.get_flattened_activation(input_activation=input_activation) + self._update_activation_covariance_matrix(input_activation=input_activation, count=count) + self.cached_hooks.append(outputs.register_hook(backward_hook)) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.factor_args.gradient_covariance_dtype + ) + # Computes and updates pseudo-gradient covariance during backward pass. + output_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient) + self._update_gradient_covariance_matrix(output_gradient=output_gradient, count=count) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + def exist(self) -> bool: + """Checks if both activation and gradient covariance matrices are available.""" + for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + if self.module.storage[covariance_factor_name] is None: + return False + return True + + def synchronize(self, num_processes: int) -> None: + """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" + del num_processes + if dist.is_initialized() and torch.cuda.is_available() and self.exist(): + for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + self.module.storage[covariance_factor_name] = self.module.storage[covariance_factor_name].cuda() + dist.reduce( + tensor=self.module.storage[covariance_factor_name], + op=dist.ReduceOp.SUM, + dst=0, + ) + + def release_memory(self) -> None: + """Clears all covariance matrices from memory.""" + self._activation_covariance_initialized = False + self._gradient_covariance_initialized = False + for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + self.module.storage[covariance_factor_name] = None + + +class LambdaTracker(BaseTracker): + """Tracks and computes Lambda matrices for a given module.""" + + def _eigendecomposition_results_exist(self) -> bool: + """Checks if eigendecomposition results are available.""" + for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: + if self.module.storage[eigen_factor_name] is None: + return False + return True + + def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: + """Computes and updates the Lambda matrix 
using provided per-sample gradient. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample gradient tensor for the given batch. + """ + batch_size = per_sample_gradient.size(0) + + if self.module.storage[NUM_LAMBDA_PROCESSED] is None: + self.module.storage[NUM_LAMBDA_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + requires_grad=False, + ) + self.module.storage[LAMBDA_MATRIX_NAME] = torch.zeros( + size=(per_sample_gradient.size(1), per_sample_gradient.size(2)), + dtype=per_sample_gradient.dtype, + device=per_sample_gradient.device, + requires_grad=False, + ) + + if FactorConfig.CONFIGS[self.module.factor_args.strategy].requires_eigendecomposition_for_lambda: + if not self._eigendecomposition_results_exist(): + raise FactorsNotFoundError( + f"The strategy {self.module.factor_args.strategy} requires eigendecomposition " + f"results for Lambda computations, but they are not found." + ) + + # Move activation and pseudo-gradient eigenvectors to appropriate devices. + self.module.storage[ACTIVATION_EIGENVECTORS_NAME] = self.module.storage[ + ACTIVATION_EIGENVECTORS_NAME + ].to( + dtype=per_sample_gradient.dtype, + device=per_sample_gradient.device, + ) + self.module.storage[GRADIENT_EIGENVECTORS_NAME] = self.module.storage[GRADIENT_EIGENVECTORS_NAME].to( + dtype=per_sample_gradient.dtype, + device=per_sample_gradient.device, + ) + + self.module.storage[NUM_LAMBDA_PROCESSED].add_(batch_size) + if FactorConfig.CONFIGS[self.module.factor_args.strategy].requires_eigendecomposition_for_lambda: + if self.module.factor_args.use_iterative_lambda_aggregation: + # This batch-wise iterative update can be useful when the GPU memory is limited. + per_sample_gradient = torch.matmul( + per_sample_gradient, + self.module.storage[ACTIVATION_EIGENVECTORS_NAME], + ) + for i in range(batch_size): + sqrt_lambda = torch.matmul( + self.module.storage[GRADIENT_EIGENVECTORS_NAME].t(), + per_sample_gradient[i], + ) + self.module.storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_()) + else: + per_sample_gradient = ( + torch.matmul( + self.module.storage[GRADIENT_EIGENVECTORS_NAME].t(), + torch.matmul(per_sample_gradient, self.module.storage[ACTIVATION_EIGENVECTORS_NAME]), + ) + .square_() + .sum(dim=0) + ) + self.module.storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient) + else: + # Approximate the eigenbasis as identity. 
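+            # Illustration: for strategies that skip the eigendecomposition (e.g., "diagonal"),
+            # Lambda stays in the parameter basis, so the update reduces to the sum of squared
+            # per-sample gradients over the batch, roughly `lambda += sum_i g_i ** 2`.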
+ per_sample_gradient = per_sample_gradient.square_().sum(dim=0) + self.module.storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient) + + def register_hooks(self) -> None: + """Sets up hooks to compute lambda matrices.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.factor_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append( + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype + ) + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ).to(dtype=self.module.factor_args.lambda_dtype) + self.clear_all_cache() + del output_gradient + # Computes and updates lambda matrix during backward pass. + self._update_lambda_matrix(per_sample_gradient=per_sample_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype + ) + cached_activation = self.cached_activations.pop() + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + # Aggregates per-sample gradients during backward pass. 
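+            # Illustration: when a module is reused k times in one forward pass, its true
+            # per-sample gradient is the sum of the k per-use gradients, so each invocation of
+            # this hook adds one contribution; the Lambda update itself is deferred to
+            # `finalize_iteration`.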
+ self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Updates Lambda matrix using cached per-sample gradients.""" + if self.module.factor_args.has_shared_parameters: + self.cached_per_sample_gradient = self.cached_per_sample_gradient.to( + dtype=self.module.factor_args.lambda_dtype + ) + self._update_lambda_matrix(per_sample_gradient=self.cached_per_sample_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if Lambda matrices are available.""" + for lambda_factor_name in LAMBDA_FACTOR_NAMES: + if self.module.storage[lambda_factor_name] is None: + return False + return True + + def synchronize(self, num_processes: int) -> None: + """Aggregates Lambda matrices across multiple devices or nodes in a distributed setting.""" + del num_processes + if dist.is_initialized() and torch.cuda.is_available() and self.exist(): + for lambda_factor_name in LAMBDA_FACTOR_NAMES: + self.module.storage[lambda_factor_name] = self.module.storage[lambda_factor_name].cuda() + dist.reduce( + tensor=self.module.storage[lambda_factor_name], + op=dist.ReduceOp.SUM, + dst=0, + ) + + def release_memory(self) -> None: + """Clears Lambda matrices from memory.""" + self.clear_all_cache() + for lambda_factor_name in LAMBDA_FACTOR_NAMES: + self.module.storage[lambda_factor_name] = None diff --git a/kronfluence/module/tracker/gradient.py b/kronfluence/module/tracker/gradient.py new file mode 100644 index 0000000..8f49fa0 --- /dev/null +++ b/kronfluence/module/tracker/gradient.py @@ -0,0 +1,93 @@ +from typing import Tuple + +import torch +import torch.distributed as dist +from torch import nn + +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import AGGREGATED_GRADIENT_NAME + + +class GradientTracker(BaseTracker): + """Tracks and computes aggregated gradient for a given module.""" + + def register_hooks(self) -> None: + """Sets up hooks to compute and keep track of aggregated gradient.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append(outputs.register_hook(backward_hook)) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + if isinstance(self.cached_activations, list): + cached_activation = self.cached_activations.pop() + else: + cached_activation = self.cached_activations + if self.module.per_sample_gradient_process_fnc is None: + summed_gradient = self.module.compute_summed_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + 
self.clear_all_cache() + else: + summed_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ).sum(dim=0, keepdim=True) + if self.module.storage[AGGREGATED_GRADIENT_NAME] is None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False) + self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + def finalize_iteration(self): + """Clears all cached data from memory.""" + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if aggregated gradient is available.""" + return self.module.storage[AGGREGATED_GRADIENT_NAME] is not None + + def synchronize(self, num_processes: int = 1) -> None: + """Aggregates summed gradient across multiple devices or nodes in a distributed setting.""" + del num_processes + if dist.is_initialized() and torch.cuda.is_available(): + if self.module.storage[AGGREGATED_GRADIENT_NAME] is None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros( + size=(1,), + dtype=self.module.score_args.per_sample_gradient_dtype, + device="cuda", + requires_grad=False, + ) + self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].contiguous() + dist.all_reduce( + tensor=self.module.storage[AGGREGATED_GRADIENT_NAME], + op=dist.ReduceOp.SUM, + ) + + def release_memory(self) -> None: + """Clears aggregated gradients from memory.""" + self.clear_all_cache() + self.module.storage[AGGREGATED_GRADIENT_NAME] = None diff --git a/kronfluence/module/tracker/pairwise_score.py b/kronfluence/module/tracker/pairwise_score.py new file mode 100644 index 0000000..b46ff67 --- /dev/null +++ b/kronfluence/module/tracker/pairwise_score.py @@ -0,0 +1,135 @@ +from typing import Tuple + +import torch +from opt_einsum import DynamicProgramming, contract_path +from torch import _VF, nn + +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, + AGGREGATED_GRADIENT_NAME, + PAIRWISE_SCORE_MATRIX_NAME, + PRECONDITIONED_GRADIENT_NAME, +) + + +class PairwiseScoreTracker(BaseTracker): + """Computes pairwise influence scores for a given module.""" + + def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tensor) -> None: + """Computes pairwise influence scores using per-sample-gradient. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample-gradient tensor for the given batch. 
+ """ + precondition_name = ACCUMULATED_PRECONDITIONED_GRADIENT_NAME + if isinstance(self.module.storage[precondition_name], list): + left_mat, right_mat = self.module.storage[precondition_name] + expr = "qki,toi,qok->qt" + if self.module.einsum_path is None: + path = contract_path( + expr, + right_mat, + per_sample_gradient, + left_mat, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + )[0] + self.module.einsum_path = [item for pair in path for item in pair] + scores = _VF.einsum(expr, (right_mat, per_sample_gradient, left_mat), path=self.module.einsum_path) # pylint: disable=no-member + else: + scores = torch.einsum( + "qio,tio->qt", + self.module.storage[precondition_name], + per_sample_gradient, + ) + + if self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] is not None: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME].add_(scores) + else: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = scores + + def register_hooks(self) -> None: + """Sets up hooks to compute pairwise influence scores.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append(outputs.register_hook(backward_hook)) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.score_dtype + ) + if isinstance(self.cached_activations, list): + cached_activation = self.cached_activations.pop() + else: + cached_activation = self.cached_activations + # Computes pairwise influence scores during backward pass. 
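+            # Illustration: the score between query `q` and train sample `t` is the inner product
+            # of the preconditioned query gradient with the training gradient, roughly
+            # `score[q, t] = (precond_grad[q] * train_grad[t]).sum()`; the low-rank factors and the
+            # cached einsum path above avoid materializing the full query gradients.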
+ if self.module.per_sample_gradient_process_fnc is None: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = self.module.compute_pairwise_score( + preconditioned_gradient=self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + del cached_activation, output_gradient + self.clear_all_cache() + else: + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + del cached_activation, output_gradient + self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + def finalize_iteration(self) -> None: + """Clears all cached data from memory.""" + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if pairwise score is available.""" + return self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] is not None + + def accumulate_iterations(self) -> None: + """Removes pairwise scores from memory after a single iteration.""" + self.release_memory() + + @torch.no_grad() + def finalize_all_iterations(self) -> None: + """Removes cached preconditioned gradient from memory. Additionally, if aggregated gradients are available, + computes the pairwise score using them.""" + if self.module.storage[AGGREGATED_GRADIENT_NAME] is not None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].to( + dtype=self.module.score_args.score_dtype + ) + self._compute_pairwise_score_with_gradient( + per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME] + ) + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + self.clear_all_cache() + + def release_memory(self) -> None: + """Releases pairwise scores from memory.""" + self.clear_all_cache() + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = None diff --git a/kronfluence/module/tracker/precondition.py b/kronfluence/module/tracker/precondition.py new file mode 100644 index 0000000..1d60dbc --- /dev/null +++ b/kronfluence/module/tracker/precondition.py @@ -0,0 +1,261 @@ +from typing import List, Tuple + +import torch +import torch.distributed as dist +from torch import nn + +from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, + AGGREGATED_GRADIENT_NAME, + PRECONDITIONED_GRADIENT_NAME, +) + + +class PreconditionTracker(BaseTracker): + """Computes preconditioned gradient for a given module.""" + + def _compute_low_rank_preconditioned_gradient( + self, + preconditioned_gradient: torch.Tensor, + target_dtype: torch.dtype, + ) -> List[torch.Tensor]: + """Performs low-rank approximation of the preconditioned gradient. + + Args: + preconditioned_gradient (torch.Tensor): + The preconditioned per-sample gradient tensor to be low-rank approximated. + target_dtype (torch.dtype): + The desired dtype for the output. + + Returns: + List[torch.Tensor, torch.Tensor]: + Low-rank matrices approximating the original preconditioned gradient. 
+ """ + rank = self.module.score_args.query_gradient_low_rank + if self.module.score_args.use_full_svd: + U, S, V = torch.linalg.svd( # pylint: disable=not-callable + preconditioned_gradient, + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + # Avoid holding the full memory of the original tensor before indexing. + V_k = V[:, :rank, :].to(dtype=target_dtype, copy=True) + left_mat = torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=target_dtype) + return [left_mat, V_k] + + U, S, V = torch.svd_lowrank(preconditioned_gradient, q=rank) + left_mat = torch.matmul(U, torch.diag_embed(S)).to(dtype=target_dtype) + V = V.transpose(1, 2).to(dtype=target_dtype) + return [left_mat, V] + + def _process_preconditioned_gradient(self, preconditioned_gradient: torch.Tensor) -> None: + """Processes the preconditioned per-sample gradient. + + Args: + preconditioned_gradient (torch.Tensor): + The preconditioned per-sample gradient tensor for the given batch. + """ + if ( + self.module.score_args.query_gradient_low_rank is not None + and min(preconditioned_gradient.size()[1:]) > self.module.score_args.query_gradient_low_rank + ): + # Apply low-rank approximation to the preconditioned gradient. + preconditioned_gradient = preconditioned_gradient.to( + dtype=self.module.score_args.query_gradient_svd_dtype + ).contiguous() + preconditioned_gradient = self._compute_low_rank_preconditioned_gradient( + preconditioned_gradient=preconditioned_gradient, + target_dtype=self.module.score_args.score_dtype, + ) + else: + preconditioned_gradient = preconditioned_gradient.to(dtype=self.module.score_args.score_dtype) + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient + + def register_hooks(self) -> None: + """Sets up hooks to compute preconditioned per-sample gradient.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append( + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ).to(dtype=self.module.score_args.precondition_dtype) + self.clear_all_cache() + del output_gradient + # Computes preconditioned per-sample gradient during backward pass. 
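+            # Illustration (writing Q_A/Q_G for the stored activation/gradient eigenvectors and
+            # `lambda` for the damped Lambda matrix): for EK-FAC, `precondition_gradient` roughly
+            # computes `Q_G @ ((Q_G.t() @ g @ Q_A) / lambda) @ Q_A.t()` per sample, i.e. it rotates
+            # the gradient into the Kronecker eigenbasis, divides elementwise, and rotates back.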
+ preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=per_sample_gradient, + storage=self.module.storage, + ) + del per_sample_gradient + self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + cached_activation = self.cached_activations.pop() + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + # Aggregates per-sample gradients during backward pass. + self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Computes preconditioned gradient using cached per-sample gradients.""" + if self.module.factor_args.has_shared_parameters: + self.cached_per_sample_gradient = self.cached_per_sample_gradient.to( + dtype=self.module.score_args.precondition_dtype + ) + preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=self.cached_per_sample_gradient, + storage=self.module.storage, + ) + self.cached_per_sample_gradient = None + self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if preconditioned gradient is available.""" + return ( + self.module.storage[PRECONDITIONED_GRADIENT_NAME] is not None + or self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] is not None + ) + + def synchronize(self, num_processes: int = 1) -> None: + """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" + if ( + dist.is_initialized() + and torch.cuda.is_available() + and self.module.storage[PRECONDITIONED_GRADIENT_NAME] is not None + ): + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + for i in range(len(self.module.storage[PRECONDITIONED_GRADIENT_NAME])): + size = self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].size() + stacked_matrix = torch.empty( + size=(num_processes, size[0], size[1], size[2]), + dtype=self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].dtype, + device=self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].device, + ) + torch.distributed.all_gather_into_tensor( + output_tensor=stacked_matrix, + input_tensor=self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(), + ) + self.module.storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).reshape( + num_processes * size[0], size[1], size[2] + ) + else: + size = self.module.storage[PRECONDITIONED_GRADIENT_NAME].size() + stacked_preconditioned_gradient = torch.empty( + size=(num_processes, size[0], size[1], size[2]), + dtype=self.module.storage[PRECONDITIONED_GRADIENT_NAME].dtype, + device=self.module.storage[PRECONDITIONED_GRADIENT_NAME].device, + ) + torch.distributed.all_gather_into_tensor( + output_tensor=stacked_preconditioned_gradient, + 
input_tensor=self.module.storage[PRECONDITIONED_GRADIENT_NAME].contiguous(), + ) + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose( + 0, 1 + ).reshape(num_processes * size[0], size[1], size[2]) + + def truncate(self, keep_size: int) -> None: + """Truncates preconditioned gradient to appropriate dimension.""" + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + assert len(self.module.storage[PRECONDITIONED_GRADIENT_NAME]) == 2 + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = [ + self.module.storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size].clone(), + self.module.storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size].clone(), + ] + else: + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = self.module.storage[PRECONDITIONED_GRADIENT_NAME][ + :keep_size + ].clone() + + def accumulate_iterations(self) -> None: + """Accumulates preconditioned gradient across multiple iterations.""" + accumulated_gradient = self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] + gradient = self.module.storage[PRECONDITIONED_GRADIENT_NAME] + + if self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] is None: + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = [ + tensor.contiguous() for tensor in gradient + ] + else: + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = gradient.contiguous() + + else: + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = [ + torch.cat((accumulated_gradient[0], gradient[0]), dim=0).contiguous(), + torch.cat((accumulated_gradient[1], gradient[1]), dim=0).contiguous(), + ] + else: + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat( + (accumulated_gradient, gradient), dim=0 + ).contiguous() + del self.module.storage[PRECONDITIONED_GRADIENT_NAME], gradient + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + + @torch.no_grad() + def finalize_all_iterations(self) -> None: + """Preconditions aggregated gradient if it exists in storage.""" + if self.module.storage[AGGREGATED_GRADIENT_NAME] is not None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].to( + dtype=self.module.score_args.precondition_dtype + ) + preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=self.module.storage[AGGREGATED_GRADIENT_NAME], + storage=self.module.storage, + ) + self.module.storage[AGGREGATED_GRADIENT_NAME] = None + self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient) + self.accumulate_iterations() + + def release_memory(self) -> None: + """Clears preconditioned gradients from memory.""" + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + self.clear_all_cache() diff --git a/kronfluence/module/tracker/self_score.py b/kronfluence/module/tracker/self_score.py new file mode 100644 index 0000000..c8230dd --- /dev/null +++ b/kronfluence/module/tracker/self_score.py @@ -0,0 +1,253 @@ +from typing import Tuple + +import torch +from torch import nn + +from kronfluence.factor.config import STORAGE_TYPE, FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + PRECONDITIONED_GRADIENT_NAME, + SELF_SCORE_VECTOR_NAME, +) + + +def 
move_storage_to_device(storage: STORAGE_TYPE, target_device: torch.device) -> None: + """Moves all stored factors in the storage dictionary to the specified target device. + + Args: + storage (STORAGE_TYPE): + A dictionary containing stored factors. + target_device (torch.device): + The target device to move the factors to. + """ + for name, factor in storage.items(): + if factor is not None: + if isinstance(factor, list): + for i in range(len(storage[name])): + storage[name][i] = factor[i].to(device=target_device) + if isinstance(factor, torch.Tensor): + storage[name] = factor.to(device=target_device) + + +class SelfScoreTracker(BaseTracker): + """Computes self-influence scores for a given module.""" + + storage_at_device: bool = False + + def _compute_self_score(self, per_sample_gradient: torch.Tensor) -> None: + """Computes self-influence scores using per-sample gradients. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample gradient tensor for the given batch. + """ + if not self.storage_at_device: + move_storage_to_device( + storage=self.module.storage, + target_device=per_sample_gradient.device, + ) + self.storage_at_device = True + + preconditioned_gradient = ( + FactorConfig.CONFIGS[self.module.factor_args.strategy] + .precondition_gradient( + gradient=per_sample_gradient, + storage=self.module.storage, + ) + .to(dtype=self.module.score_args.score_dtype) + ) + per_sample_gradient = per_sample_gradient.to(dtype=self.module.score_args.score_dtype) + preconditioned_gradient.mul_(per_sample_gradient) + self.module.storage[SELF_SCORE_VECTOR_NAME] = preconditioned_gradient.sum(dim=(1, 2)) + + def register_hooks(self) -> None: + """Sets up hooks to compute self-influence scores.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append( + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ).to(dtype=self.module.score_args.precondition_dtype) + self.clear_all_cache() + del output_gradient + self._compute_self_score(per_sample_gradient=per_sample_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + cached_activation = self.cached_activations.pop() + 
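As a minimal sketch of the arithmetic inside `_compute_self_score` above: each sample's self-influence score is the inner product between its per-sample gradient and the preconditioned gradient, reduced over the parameter dimensions. The shapes below are illustrative, and identity preconditioning stands in for the fitted factors so the snippet runs on its own.

```python
import torch

# Illustrative shapes: 4 samples, per-sample gradients of shape (out_dim, in_dim).
per_sample_gradient = torch.randn(4, 8, 16)

# Stand-in for `FactorConfig.CONFIGS[strategy].precondition_gradient`: identity
# preconditioning here; the real implementation applies the inverse curvature factors.
preconditioned_gradient = per_sample_gradient.clone()

# Elementwise product, then a sum over the parameter dimensions (dims 1 and 2),
# mirroring `preconditioned_gradient.mul_(per_sample_gradient).sum(dim=(1, 2))`.
self_scores = (preconditioned_gradient * per_sample_gradient).sum(dim=(1, 2))
print(self_scores.shape)  # torch.Size([4]) -- one score per sample
```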
per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + # Aggregates per-sample gradients during backward pass. + self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Computes self-influence scores using cached per-sample gradients.""" + if self.module.factor_args.has_shared_parameters: + self.cached_per_sample_gradient = self.cached_per_sample_gradient.to( + dtype=self.module.score_args.precondition_dtype + ) + self._compute_self_score(per_sample_gradient=self.cached_per_sample_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if self-influence score is available.""" + return self.module.storage[SELF_SCORE_VECTOR_NAME] is not None + + def accumulate_iterations(self) -> None: + """Removes self-influence scores from memory after a single iteration.""" + self.release_memory() + + def release_memory(self) -> None: + """Releases self-influence scores from memory.""" + self.clear_all_cache() + if self.storage_at_device: + move_storage_to_device(storage=self.module.storage, target_device=torch.device("cpu")) + self.storage_at_device = False + del self.module.storage[SELF_SCORE_VECTOR_NAME] + self.module.storage[SELF_SCORE_VECTOR_NAME] = None + + +class SelfScoreWithMeasurementTracker(BaseTracker): + """Computes self-influence scores with measurement for a given module.""" + + storage_at_device: bool = False + + def _compute_self_measurement_score_with_gradient(self, per_sample_gradient: torch.Tensor) -> None: + """Computes self-influence scores with measurement using per-sample gradients. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample gradient tensor for the given batch. 
+ """ + scores = per_sample_gradient.mul_(self.module.storage[PRECONDITIONED_GRADIENT_NAME]).sum(dim=(1, 2)) + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + if self.module.storage[SELF_SCORE_VECTOR_NAME] is None: + self.module.storage[SELF_SCORE_VECTOR_NAME] = scores + else: + self.module.storage[SELF_SCORE_VECTOR_NAME].add_(scores) + + def register_hooks(self) -> None: + """Sets up hooks to compute self-influence scores with measurement.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append(outputs.register_hook(backward_hook)) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + + if not self.storage_at_device: + move_storage_to_device( + storage=self.module.storage, + target_device=output_gradient.device, + ) + self.storage_at_device = True + + handle = self.cached_hooks.pop() + handle.remove() + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.score_dtype + ) + if isinstance(self.cached_activations, list): + cached_activation = self.cached_activations.pop() + else: + cached_activation = self.cached_activations + if self.module.per_sample_gradient_process_fnc is None: + scores = self.module.compute_self_measurement_score( + preconditioned_gradient=self.module.storage[PRECONDITIONED_GRADIENT_NAME], + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + self.clear_all_cache() + if self.module.storage[SELF_SCORE_VECTOR_NAME] is None: + self.module.storage[SELF_SCORE_VECTOR_NAME] = scores + else: + self.module.storage[SELF_SCORE_VECTOR_NAME].add_(scores) + else: + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + del cached_activation, output_gradient + self._compute_self_measurement_score_with_gradient(per_sample_gradient=per_sample_gradient) + + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) + + def finalize_iteration(self) -> None: + """Clears all cached data from memory.""" + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if self-influence score is available.""" + return self.module.storage[SELF_SCORE_VECTOR_NAME] is not None + + def accumulate_iterations(self) -> None: + """Removes self-influence scores from memory after a single iteration.""" + self.release_memory() + + def release_memory(self) -> None: + """Releases self-influence scores from memory.""" + self.clear_all_cache() + if self.storage_at_device: + move_storage_to_device(storage=self.module.storage, target_device=torch.device("cpu")) + self.storage_at_device = False + del self.module.storage[SELF_SCORE_VECTOR_NAME] + self.module.storage[SELF_SCORE_VECTOR_NAME] = None diff 
--git a/kronfluence/module/utils.py b/kronfluence/module/utils.py index 8fc6bf6..e858121 100644 --- a/kronfluence/module/utils.py +++ b/kronfluence/module/utils.py @@ -13,8 +13,18 @@ def _get_submodules(model: nn.Module, key: str) -> Tuple[nn.Module, str]: - """Returns the parent module and its name given the name of the current module.""" - # The code is modified from: https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py. + """Retrieves the parent module and its name given the name of the current module. + + Args: + model (nn.Module): + The PyTorch model to inspect. + key (str): + The full name of the current module. + + Returns: + Tuple[nn.Module, str]: + The parent module and the name of the target module. + """ parent = model.get_submodule(".".join(key.split(".")[:-1])) target_name = key.split(".")[-1] return parent, target_name @@ -34,13 +44,13 @@ def wrap_tracked_modules( task (Task): The specific task associated with the model. factor_args (FactorArguments, optional): - Arguments related to computing the influence factors. + Arguments related to computing influence factors. score_args (ScoreArguments, optional): - Arguments related to computing the influence scores. + Arguments related to computing influence scores. Returns: nn.Module: - The wrapped Pytorch model with `TrackedModule` installed. + The processed model with `TrackedModule` installed. """ if isinstance(model, (DP, DDP, FSDP)): raise ValueError( @@ -48,32 +58,34 @@ def wrap_tracked_modules( "or FullyShardedDataParallel. Call `prepare_model` before wrapping the model." ) - tracked_module_count = 0 - tracked_module_names = task.tracked_modules() if task is not None else None + tracked_module_names = task.get_influence_tracked_modules() if task is not None else None tracked_module_exists_dict = None if tracked_module_names is not None: tracked_module_exists_dict = {name: False for name in tracked_module_names} + per_sample_gradient_process_fnc = None + if task is not None and task.enable_post_process_per_sample_gradient: + per_sample_gradient_process_fnc = task.post_process_per_sample_gradient named_modules = model.named_modules() for module_name, module in named_modules: if len(list(module.children())) > 0: continue - # Filter modules based on the task's `influence_modules` if specified. + # Filters modules based on the task's `get_influence_tracked_modules` if specified. if tracked_module_names is not None and module_name not in tracked_module_names: continue - # Wrap the module if it is currently supported (e.g., nn.Linear & nn.Conv2d). + # Wraps the module if it is currently supported (e.g., nn.Linear & nn.Conv2d). if isinstance(module, tuple(TrackedModule.SUPPORTED_MODULES)): tracked_module = TrackedModule.SUPPORTED_MODULES[type(module)]( name=module_name, original_module=module, + per_sample_gradient_process_fnc=per_sample_gradient_process_fnc, factor_args=factor_args, score_args=score_args, ) parent, target_name = _get_submodules(model=model, key=module_name) setattr(parent, target_name, tracked_module) - tracked_module_count += 1 if tracked_module_exists_dict is not None: tracked_module_exists_dict[module_name] = True @@ -84,22 +96,34 @@ def wrap_tracked_modules( ) raise IllegalTaskConfigurationError(error_msg) - if tracked_module_count == 0: - supported_modules_names = [module.__name__ for module in TrackedModule.SUPPORTED_MODULES] - error_msg = ( - f"Kronfluence currently supports following PyTorch modules: `{supported_modules_names}`. 
" - f"However, these modules were not found in the provided model. If you want to analyze " - "custom layers, consider rewriting your model to use the supported modules, " - "or define your own custom module by subclassing `TrackedModule`." + if not any(isinstance(module, TrackedModule) for module in model.modules()): + supported_modules = ", ".join(module.__name__ for module in TrackedModule.SUPPORTED_MODULES) + raise IllegalTaskConfigurationError( + f"No supported modules found. Kronfluence supports: {supported_modules}. " + "Consider rewriting your model or subclassing `TrackedModule` for custom layers.\n" + f"Current Model:\n{model}" ) - error_msg += f"\nCurrent Model:\n{model}" - raise IllegalTaskConfigurationError(error_msg) - return model def make_modules_partition(total_module_names: List[str], partition_size: int) -> List[List[str]]: - """Divides a list of module names into smaller partitions of a specified size.""" + """Divides a list of module names into smaller partitions of a specified size. + + Args: + total_module_names (List[str]): + The list of all module names. + partition_size (int): + The number of partitions to create. + + Returns: + List[List[str]]: + A list of partitioned module names. + + Raises: + ValueError: If `len(total_module_names)` is less than `partition_size`. + """ + if len(total_module_names) < partition_size: + raise ValueError("The total modules must be equal to or greater than the partition size.") # See https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length. div, mod = divmod(len(total_module_names), partition_size) return list( @@ -107,132 +131,146 @@ def make_modules_partition(total_module_names: List[str], partition_size: int) - ) -def update_factor_args(model: nn.Module, factor_args: FactorArguments) -> None: - """Updates the factor arguments for all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.update_factor_args(factor_args=factor_args) - - -def update_score_args(model: nn.Module, score_args: ScoreArguments) -> None: - """Updates the score arguments for all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.update_score_args(score_args=score_args) - - -def get_tracked_module_names(model: nn.Module) -> List[str]: - """Returns the names of `TrackedModule` instances within a model.""" - tracked_modules = [] - for module in model.modules(): - if isinstance(module, TrackedModule): - tracked_modules.append(module.name) - return tracked_modules - +def set_mode( + model: nn.Module, + mode: ModuleMode, + tracked_module_names: List[str] = None, + release_memory: bool = False, +) -> None: + """Sets the module mode of specified `TrackedModule` instances within a model. -def synchronize_covariance_matrices(model: nn.Module) -> None: - """Synchronizes covariance matrices of all `TrackedModule` instances within a model.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + mode (ModuleMode): + The new mode to set for `TrackedModule`. + tracked_module_names (List[str], optional): + Names of modules to update. If `None`, updates all. + release_memory (bool, optional): + If `True`, releases memory of existing factors. 
+ """ for module in model.modules(): if isinstance(module, TrackedModule): - module.synchronize_covariance_matrices() - + if tracked_module_names is not None and module.name not in tracked_module_names: + continue + module.set_mode(mode=mode, release_memory=release_memory) -def synchronize_lambda_matrices(model: nn.Module) -> None: - """Synchronizes Lambda matrices of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.synchronize_lambda_matrices() +def update_factor_args(model: nn.Module, factor_args: FactorArguments) -> None: + """Updates the factor arguments for all `TrackedModule` instances within a model. -def accumulate_preconditioned_gradient(model: nn.Module) -> None: - """Accumulates preconditioned gradient of all `TrackedModule` instances within a model.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + factor_args (FactorArguments): + The new factor arguments to set. + """ for module in model.modules(): if isinstance(module, TrackedModule): - module.accumulate_preconditioned_gradient() - + module.update_factor_args(factor_args=factor_args) -def release_preconditioned_gradient(model: nn.Module) -> None: - """Releases preconditioned gradient of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.release_preconditioned_gradient() +def update_score_args(model: nn.Module, score_args: ScoreArguments) -> None: + """Updates the score arguments for all `TrackedModule` instances within a model. -def truncate_preconditioned_gradient(model: nn.Module, keep_size: int) -> None: - """Truncates preconditioned gradient of all `TrackedModule` instances within a model.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + score_args (ScoreArguments): + The new score arguments to set. + """ for module in model.modules(): if isinstance(module, TrackedModule): - module.truncate_preconditioned_gradient(keep_size=keep_size) + module.update_score_args(score_args=score_args) -def synchronize_preconditioned_gradient(model: nn.Module, num_processes: int) -> None: - """Synchronizes preconditioned gradient of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.synchronize_preconditioned_gradient(num_processes=num_processes) +def get_tracked_module_names(model: nn.Module) -> List[str]: + """Returns the names of `TrackedModule` instances within a model. + Args: + model (nn.Module): + The PyTorch model to inspect. -def release_scores(model: nn.Module) -> None: - """Releases scores of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.release_scores() + Returns: + List[str]: + A list of names of `TrackedModule` instances. + """ + return [module.name for module in model.modules() if isinstance(module, TrackedModule)] -def set_mode( +def load_factors( model: nn.Module, - mode: ModuleMode, + factor_name: str, tracked_module_names: List[str] = None, - keep_factors: bool = False, -) -> None: - """Sets the module mode of all `TrackedModule` instances within a model. For example, to compute - and aggregate the covariance matrices, the module mode needs to be set to `ModuleMode.COVARIANCE`. + cpu: bool = True, +) -> Dict[str, torch.Tensor]: + """Loads factors with the given name from specified `TrackedModule` instances. 
Args: model (nn.Module): - The PyTorch model which contains `TrackedModule`. - mode (ModuleMode): - The new mode to set for `TrackedModule`. - tracked_module_names (List[str], optional): - The list of names for `TrackedModule` to set the new mode. If not provided, the new mode is - set for all available `TrackedModule`. - keep_factors (bool, optional): - If True, existing factors are kept in memory. Defaults to False. + The PyTorch model containing `TrackedModule` instances. + factor_name (str): + The name of the factor to load. + tracked_module_names (Optional[List[str]]): + Names of modules to load from. If `None`, loads from all. + cpu (bool): + If `True`, moves factors to CPU and releases GPU memory. + + Returns: + Dict[str, torch.Tensor]: + A dictionary of loaded factors, keyed by module name. """ + loaded_factors = {} for module in model.modules(): if isinstance(module, TrackedModule): if tracked_module_names is not None and module.name not in tracked_module_names: continue - module.set_mode(mode=mode, keep_factors=keep_factors) - - -def load_factors( - model: nn.Module, - factor_name: str, - clone: bool = False, -) -> Dict[str, torch.Tensor]: - """Loads factors with the given name from all `TrackedModule` instances within a model.""" - loaded_factors = {} - for module in model.modules(): - if isinstance(module, TrackedModule): factor = module.get_factor(factor_name=factor_name) if factor is not None: - loaded_factors[module.name] = factor.contiguous().clone() if clone else factor + if cpu: + loaded_factors[module.name] = factor.to(device="cpu", copy=True) + module.release_factor(factor_name=factor_name) + else: + loaded_factors[module.name] = factor return loaded_factors -def set_factors(model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor]) -> None: - """Sets new factor for all `TrackedModule` instances within a model.""" +def set_factors(model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor], clone: bool = False) -> None: + """Sets new factors for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + factor_name (str): + The name of the factor to set. + factors (Dict[str, torch.Tensor]): + A dictionary of factors to set, keyed by module name. + clone (bool): + If `True`, clones the factors before setting. + """ for module in model.modules(): if isinstance(module, TrackedModule): - module.set_factor(factor_name=factor_name, factor=factors[module.name]) + module.set_factor( + factor_name=factor_name, factor=factors[module.name].clone() if clone else factors[module.name] + ) def set_attention_mask( model: nn.Module, attention_mask: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]] = None, ) -> None: - """Sets the attention mask for all `TrackedModule` instances within a model.""" + """Sets the attention mask for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + attention_mask (Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]): + The attention mask to set. Can be a dictionary (keyed by module name) or a single tensor. + + Raises: + RuntimeError: + If an invalid attention mask is provided. 
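Since `set_attention_mask` accepts either a single tensor or a per-module dictionary, the dispatch below mirrors its branching in miniature, with a dummy class standing in for `TrackedModule` so the snippet runs without the library:

```python
from typing import Dict, List, Optional, Union

import torch


class DummyTracked:
    """Records the attention mask it receives; stands in for TrackedModule."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.mask: Optional[torch.Tensor] = None

    def set_attention_mask(self, attention_mask: Optional[torch.Tensor]) -> None:
        self.mask = attention_mask


def dispatch_mask(
    modules: List[DummyTracked],
    attention_mask: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]],
) -> None:
    # Mirrors the accepted forms: a per-module dict (names missing from the
    # dict receive None), one shared tensor, or None for everything.
    for module in modules:
        if isinstance(attention_mask, dict):
            module.set_attention_mask(attention_mask=attention_mask.get(module.name))
        elif isinstance(attention_mask, torch.Tensor) or attention_mask is None:
            module.set_attention_mask(attention_mask=attention_mask)
        else:
            raise RuntimeError(f"Invalid attention mask `{attention_mask}` provided.")


modules = [DummyTracked("encoder.query"), DummyTracked("encoder.key")]
dispatch_mask(modules, {"encoder.query": torch.ones(2, 5, dtype=torch.long)})
print(modules[0].mask is not None, modules[1].mask is None)  # True True
```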
+ """ for module in model.modules(): if isinstance(module, TrackedModule): if isinstance(attention_mask, dict): @@ -242,6 +280,8 @@ def set_attention_mask( module.set_attention_mask(attention_mask=None) elif isinstance(attention_mask, torch.Tensor): module.set_attention_mask(attention_mask=attention_mask) + elif attention_mask is None: + module.set_attention_mask(attention_mask=None) else: raise RuntimeError(f"Invalid attention mask `{attention_mask}` provided.") @@ -250,42 +290,124 @@ def set_gradient_scale( model: nn.Module, gradient_scale: float = 1.0, ) -> None: - """Sets the gradient scale for all `TrackedModule` instances within a model.""" + """Sets the gradient scale for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + gradient_scale (float): + The gradient scale to set. + """ for module in model.modules(): if isinstance(module, TrackedModule): module.set_gradient_scale(scale=gradient_scale) -def finalize_lambda_matrices(model: nn.Module) -> None: - """Updates Lambda matrices of all `TrackedModule` instances within a model.""" +def prepare_modules(model: nn.Module, tracked_module_names: List[str], device: torch.device) -> None: + """Prepares specified `TrackedModule` instances for score computation. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to prepare. + device (torch.device): + The device to prepare the modules for. + """ for module in model.modules(): - if isinstance(module, TrackedModule): - module.finalize_lambda_matrix() + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.prepare_storage(device=device) -def finalize_preconditioned_gradient(model: nn.Module) -> None: - """Computes preconditioned gradient of all `TrackedModule` instances within a model.""" +def synchronize_modules(model: nn.Module, tracked_module_names: List[str], num_processes: int = 1) -> None: + """Synchronizes specified `TrackedModule` instances across processes. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to synchronize. + num_processes (int): + The number of processes to synchronize across. + """ for module in model.modules(): - if isinstance(module, TrackedModule): - module.finalize_preconditioned_gradient() + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.synchronize(num_processes=num_processes) -def finalize_pairwise_scores(model: nn.Module) -> None: - """Computes pairwise influence scores of all `TrackedModule` instances within a model.""" +def truncate(model: nn.Module, tracked_module_names: List[str], keep_size: int) -> None: + """Truncates the data in specified `TrackedModule` instances. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to truncate. + keep_size (int): + The number of elements to keep after truncation. + """ for module in model.modules(): - if isinstance(module, TrackedModule): - module.finalize_pairwise_score() + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.truncate(keep_size=keep_size) + + +def exist_for_all_modules(model: nn.Module, tracked_module_names: List[str]) -> bool: + """Checks if all specified `TrackedModule` instances have existing factor. 
+ + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to check. + + Returns: + bool: + `True` if all specified modules have existing factors, `False` otherwise. + """ + return all( + module.exist() + for module in model.modules() + if isinstance(module, TrackedModule) and module.name in tracked_module_names + ) + +def accumulate_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: + """Accumulates iterations for specified `TrackedModule` instances. -def finalize_self_scores(model: nn.Module) -> None: - """Computes self-influence scores of all `TrackedModule` instances within a model.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to accumulate iterations for. + """ for module in model.modules(): - if isinstance(module, TrackedModule): - module.finalize_self_score() + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.accumulate_iterations() + + +def finalize_iteration(model: nn.Module, tracked_module_names: List[str]) -> None: + """Finalizes the current iteration for specified `TrackedModule` instances. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to finalize iteration for. + """ + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.finalize_iteration() + +def finalize_all_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: + """Finalizes all iterations for specified `TrackedModule` instances. -def finalize_self_measurement_scores(model: nn.Module) -> None: - """Computes self-influence scores with measurement of all `TrackedModule` instances within a model.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to finalize all iterations for. 
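Taken together, these helpers define the per-batch protocol the score loops below follow: a backward pass populates each tracker, `finalize_iteration` folds any shared-parameter state into a result, the result is read off, and `accumulate_iterations` stores it and resets for the next batch. A toy tracker demonstrating only that call ordering (the actual trackers compute gradients, not sums):

```python
from typing import List


class ToyTracker:
    """Stand-in that mimics the tracker lifecycle, not its mathematics."""

    def __init__(self) -> None:
        self.cached: List[float] = []      # populated during the backward pass
        self.result: float = 0.0
        self.accumulated: List[float] = []

    def capture(self, value: float) -> None:
        self.cached.append(value)          # plays the role of a backward hook

    def finalize_iteration(self) -> None:
        self.result = sum(self.cached)     # fold cached per-sample state
        self.cached.clear()

    def accumulate_iterations(self) -> None:
        self.accumulated.append(self.result)  # keep the finished iteration
        self.result = 0.0                     # reset for the next batch


tracker = ToyTracker()
for batch in ([1.0, 2.0], [3.0]):          # two "batches"
    for sample in batch:
        tracker.capture(sample)
    tracker.finalize_iteration()
    tracker.accumulate_iterations()
print(tracker.accumulated)  # [3.0, 3.0]
```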
+ """ for module in model.modules(): - if isinstance(module, TrackedModule): - module.finalize_self_measurement_score() + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.finalize_all_iterations() diff --git a/kronfluence/score/__init__.py b/kronfluence/score/__init__.py index 49a4267..41ac691 100644 --- a/kronfluence/score/__init__.py +++ b/kronfluence/score/__init__.py @@ -1,5 +1,5 @@ from .pairwise import ( - _compute_dot_products_with_loader, + compute_dot_products_with_loader, compute_pairwise_scores_with_loaders, load_pairwise_scores, pairwise_scores_exist, diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py new file mode 100644 index 0000000..cd96f3e --- /dev/null +++ b/kronfluence/score/dot_product.py @@ -0,0 +1,257 @@ +from typing import Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from accelerate.utils import send_to_device +from torch import autocast, nn +from torch.cuda.amp import GradScaler +from torch.utils import data +from tqdm import tqdm + +from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.module import TrackedModule +from kronfluence.module.tracked_module import ModuleMode +from kronfluence.module.utils import ( + accumulate_iterations, + exist_for_all_modules, + finalize_all_iterations, + finalize_iteration, + set_mode, + synchronize_modules, +) +from kronfluence.task import Task +from kronfluence.utils.constants import ( + ALL_MODULE_NAME, + DISTRIBUTED_SYNC_INTERVAL, + PAIRWISE_SCORE_MATRIX_NAME, + SCORE_TYPE, +) +from kronfluence.utils.logger import TQDM_BAR_FORMAT +from kronfluence.utils.state import State, no_sync, release_memory + +DIMENSION_NOT_MATCH_ERROR_MSG = ( + "The model does not support token-wise score computation. " + "Set `compute_per_module_scores=True` or `compute_per_token_scores=False` " + "to avoid this error." 
+) + + +def compute_dot_products_with_loader( + model: nn.Module, + task: Task, + state: State, + train_loader: data.DataLoader, + factor_args: FactorArguments, + score_args: ScoreArguments, + tracked_module_names: List[str], + scaler: GradScaler, + disable_tqdm: bool = False, +) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + """After computing the preconditioned query gradient, compute dot products with individual training gradients.""" + model.zero_grad(set_to_none=True) + set_mode( + model=model, + mode=ModuleMode.PAIRWISE_SCORE, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + + dataset_size = len(train_loader.dataset) + score_chunks: Dict[str, List[torch.Tensor]] = {} + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name] = [] + else: + score_chunks[ALL_MODULE_NAME] = [] + + total_steps = 0 + enable_amp = score_args.amp_dtype is not None + + with tqdm( + total=len(train_loader), + desc="Computing pairwise scores (training gradient)", + bar_format=TQDM_BAR_FORMAT, + disable=not state.is_main_process or disable_tqdm, + ) as pbar: + for batch in train_loader: + batch = send_to_device(tensor=batch, device=state.device) + + with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) + with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): + loss = task.compute_train_loss( + batch=batch, + model=model, + sample=False, + ) + scaler.scale(loss).backward() + + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + with torch.no_grad(): + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to(device="cpu", copy=True) + ) + else: + pairwise_scores = None + for module in cached_module_lst: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError as exc: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) from exc + raise + pairwise_scores = pairwise_scores.cpu() + score_chunks[ALL_MODULE_NAME].append(pairwise_scores) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: + state.wait_for_everyone() + + del loss + total_steps += 1 + pbar.update(1) + + model.zero_grad(set_to_none=True) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + total_scores: SCORE_TYPE = {} + for module_name, chunks in score_chunks.items(): + total_scores[module_name] = torch.cat(chunks, dim=1) + if state.use_distributed: + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + 
dist.gather(total_scores[module_name], gather_list) + if state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=1)[:, :dataset_size].cpu() + else: + total_scores[module_name] = total_scores[module_name].cpu() + state.wait_for_everyone() + + return total_scores + + +def compute_aggregated_dot_products_with_loader( + model: nn.Module, + task: Task, + state: State, + train_loader: data.DataLoader, + factor_args: FactorArguments, + score_args: ScoreArguments, + tracked_module_names: List[str], + scaler: GradScaler, + disable_tqdm: bool = False, +) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + """After computing the preconditioned query gradient, compute dot products with aggregated training gradients.""" + model.zero_grad(set_to_none=True) + set_mode( + model=model, + mode=ModuleMode.GRADIENT_AGGREGATION, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + scores: Dict[str, Optional[torch.Tensor]] = {} + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + scores[module.name] = None + else: + scores[ALL_MODULE_NAME] = None + + enable_amp = score_args.amp_dtype is not None + + if not exist_for_all_modules(model=model, tracked_module_names=tracked_module_names): + with tqdm( + total=len(train_loader), + desc="Computing pairwise scores (training gradient)", + bar_format=TQDM_BAR_FORMAT, + disable=not state.is_main_process or disable_tqdm, + ) as pbar: + for batch in train_loader: + batch = send_to_device(tensor=batch, device=state.device) + + with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) + with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): + loss = task.compute_train_loss( + batch=batch, + model=model, + sample=False, + ) + scaler.scale(loss).backward() + + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + del loss + pbar.update(1) + + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + + set_mode( + model=model, + mode=ModuleMode.PAIRWISE_SCORE, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + + with torch.no_grad(): + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + scores[module.name] = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to( + device="cpu", copy=True + ) + else: + pairwise_scores = None + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError as exc: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) from exc + raise + scores[ALL_MODULE_NAME] = pairwise_scores.cpu() + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + model.zero_grad(set_to_none=True) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() 
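`compute_dot_products_with_loader` above finishes by gathering each rank's score chunk onto the main process with `dist.gather`, concatenating along the training dimension, and trimming to the true dataset size. A single-process sketch of that reassembly (the gather is replaced by a plain list, since no process group is running):

```python
import torch

num_processes, dataset_size = 2, 5

# Pretend each rank produced scores of shape (num_query, num_train_on_rank);
# rank padding means the concatenated width can exceed the real dataset size.
rank_chunks = [torch.full((3, 3), fill_value=float(rank)) for rank in range(num_processes)]

# On the main process, `dist.gather` would fill a gather_list with one tensor
# per rank; here the list itself stands in. Concatenate along the training
# dimension (dim=1) and drop the padded columns beyond `dataset_size`.
total_scores = torch.cat(rank_chunks, dim=1)[:, :dataset_size]
print(total_scores.shape)  # torch.Size([3, 5])
```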
+ + return scores diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index bb3b2a0..abc4a1f 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ -10,40 +10,47 @@ from tqdm import tqdm from kronfluence.arguments import FactorArguments, ScoreArguments -from kronfluence.module import TrackedModule from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( - accumulate_preconditioned_gradient, - finalize_pairwise_scores, - finalize_preconditioned_gradient, + accumulate_iterations, + finalize_all_iterations, + finalize_iteration, get_tracked_module_names, - release_preconditioned_gradient, - release_scores, + prepare_modules, set_factors, set_gradient_scale, set_mode, - synchronize_preconditioned_gradient, - truncate_preconditioned_gradient, + synchronize_modules, + truncate, update_factor_args, update_score_args, ) -from kronfluence.task import Task -from kronfluence.utils.constants import ( - ALL_MODULE_NAME, - FACTOR_TYPE, - PAIRWISE_SCORE_MATRIX_NAME, - PARTITION_TYPE, - SCORE_TYPE, +from kronfluence.score.dot_product import ( + compute_aggregated_dot_products_with_loader, + compute_dot_products_with_loader, ) +from kronfluence.task import Task +from kronfluence.utils.constants import FACTOR_TYPE, PARTITION_TYPE, SCORE_TYPE from kronfluence.utils.logger import TQDM_BAR_FORMAT -from kronfluence.utils.state import State, no_sync, release_memory +from kronfluence.utils.state import State, no_sync def pairwise_scores_save_path( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading pairwise scores.""" + """Generates the path for saving or loading pairwise influence scores. + + Args: + output_dir (Path): + Directory to save or load the matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the score file. + """ if partition is not None: data_partition, module_partition = partition return output_dir / ( @@ -52,25 +59,24 @@ def pairwise_scores_save_path( return output_dir / "pairwise_scores.safetensors" -def pairwise_scores_exist( - output_dir: Path, - partition: Optional[PARTITION_TYPE] = None, -) -> bool: - """Checks if the pairwise scores exist at the specified path.""" - save_path = pairwise_scores_save_path( - output_dir=output_dir, - partition=partition, - ) - return save_path.exists() - - def save_pairwise_scores( output_dir: Path, scores: SCORE_TYPE, partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves pairwise influence scores to disk.""" + """Saves pairwise scores to disk. + + Args: + output_dir (Path): + Directory to save the scores. + scores (SCORE_TYPE): + Dictionary of scores to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the scores. + """ save_path = pairwise_scores_save_path( output_dir=output_dir, partition=partition, @@ -81,8 +87,19 @@ def save_pairwise_scores( def load_pairwise_scores( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, -) -> Dict[str, torch.Tensor]: - """Loads pairwise scores from disk.""" +) -> SCORE_TYPE: + """Loads pairwise scores from disk. + + Args: + output_dir (Path): + Directory to load the scores from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + SCORE_TYPE: + Dictionary of loaded scores. 
+ """ save_path = pairwise_scores_save_path( output_dir=output_dir, partition=partition, @@ -90,113 +107,27 @@ def load_pairwise_scores( return load_file(filename=save_path) -def _compute_dot_products_with_loader( - model: nn.Module, - task: Task, - state: State, - train_loader: data.DataLoader, - factor_args: FactorArguments, - score_args: ScoreArguments, - tracked_module_names: List[str], - scaler: GradScaler, - disable_tqdm: bool = False, -) -> Union[Dict[str, torch.Tensor], torch.Tensor]: - """After computing the preconditioned query gradient, compute dot products with training gradients.""" - with torch.no_grad(): - model.zero_grad(set_to_none=True) - set_mode( - model=model, - mode=ModuleMode.PAIRWISE_SCORE, - tracked_module_names=tracked_module_names, - keep_factors=True, - ) - release_memory() - - dataset_size = len(train_loader.dataset) - score_chunks: Dict[str, List[torch.Tensor]] = {} - if score_args.per_module_score: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name] = [] - else: - score_chunks[ALL_MODULE_NAME] = [] - - total_steps = 0 - enable_amp = score_args.amp_dtype is not None - - with tqdm( - total=len(train_loader), - desc="Computing pairwise scores (training gradient)", - bar_format=TQDM_BAR_FORMAT, - disable=not state.is_main_process or disable_tqdm, - ) as pbar: - for batch in train_loader: - batch = send_to_device(tensor=batch, device=state.device) - - model.zero_grad(set_to_none=True) - with no_sync(model=model, state=state): - with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): - loss = task.compute_train_loss( - batch=batch, - model=model, - sample=False, - ) - scaler.scale(loss).backward() - - if factor_args.shared_parameters_exist: - finalize_pairwise_scores(model=model) - - with torch.no_grad(): - if score_args.per_module_score: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).cpu() - ) - else: - # Aggregates the pairwise scores across all modules. - pairwise_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if pairwise_scores is None: - pairwise_scores = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone() - else: - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) - score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) - release_scores(model=model) - - if state.use_distributed and total_steps % score_args.distributed_sync_steps == 0: - # Periodically synchronizes all processes to avoid timeout at the final synchronization. - state.wait_for_everyone() - - total_steps += 1 - pbar.update(1) - - with torch.no_grad(): - model.zero_grad(set_to_none=True) - set_mode( - model=model, - mode=ModuleMode.PRECONDITION_GRADIENT, - tracked_module_names=tracked_module_names, - keep_factors=True, - ) - release_preconditioned_gradient(model=model) - release_memory() +def pairwise_scores_exist( + output_dir: Path, + partition: Optional[PARTITION_TYPE] = None, +) -> bool: + """Checks if pairwise influence scores exist at the specified directory. 
- total_scores: SCORE_TYPE = {} - for module_name, chunks in score_chunks.items(): - total_scores[module_name] = torch.cat(chunks, dim=1) - if state.use_distributed: - total_scores[module_name] = total_scores[module_name].to(device=state.device) - gather_list = None - if state.is_main_process: - gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) - if state.is_main_process: - total_scores[module_name] = torch.cat(gather_list, dim=1)[:, :dataset_size].cpu() - state.wait_for_everyone() + Args: + output_dir (Path): + Directory to check for scores. + partition (PARTITION_TYPE, optional): + Partition information, if any. - return total_scores + Returns: + bool: + `True` if scores exist, `False` otherwise. + """ + save_path = pairwise_scores_save_path( + output_dir=output_dir, + partition=partition, + ) + return save_path.exists() def compute_pairwise_scores_with_loaders( @@ -216,7 +147,7 @@ def compute_pairwise_scores_with_loaders( Args: loaded_factors (FACTOR_TYPE): - The factor results to load from, before computing the pairwise scores. + Computed factors. model (nn.Module): The model for which pairwise influence scores will be computed. state (State): @@ -230,38 +161,38 @@ def compute_pairwise_scores_with_loaders( train_loader (data.DataLoader): The data loader that will be used to compute training gradients. score_args (ScoreArguments): - Arguments related to computing pairwise scores. + Arguments for computing pairwise scores. factor_args (FactorArguments): - Arguments related to computing preconditioning factors. + Arguments used to compute factors. tracked_module_names (List[str], optional): A list of module names that pairwise scores will be computed. If not specified, scores will be computed for all available tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: - Dict[str, torch.Tensor]: + SCORE_TYPE: A dictionary containing the module name and its pairwise influence scores. """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - mode=ModuleMode.PRECONDITION_GRADIENT, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - # Loads necessary factors before computing pairwise influence scores. 
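At a high level, `compute_pairwise_scores_with_loaders` alternates two phases: it preconditions a block of query gradients (optionally accumulated over several query batches), then streams over the training loader taking dot products. A self-contained toy of that control flow, with random tensors standing in for gradients and identity preconditioning standing in for the fitted factors:

```python
import torch


def precondition(gradient: torch.Tensor) -> torch.Tensor:
    # Identity preconditioning keeps the toy runnable; the real code applies
    # the fitted (E)KFAC factors here.
    return gradient


query_gradients = [torch.randn(2, 4) for _ in range(3)]  # 3 query batches of 2
train_gradients = [torch.randn(2, 4) for _ in range(2)]  # 2 train batches of 2

score_blocks = []
for query_gradient in query_gradients:
    preconditioned = precondition(query_gradient)        # phase 1: query side
    # Phase 2: dot products against every training gradient in the loader.
    block = torch.cat([preconditioned @ train.T for train in train_gradients], dim=1)
    score_blocks.append(block)

pairwise_scores = torch.cat(score_blocks, dim=0)         # (num_query, num_train)
print(pairwise_scores.shape)  # torch.Size([6, 4])
```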
- if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors( + model=model, + factor_name=name, + factors=loaded_factors[name], + clone=True, + ) + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) total_scores_chunks: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {} total_query_batch_size = per_device_query_batch_size * state.num_processes @@ -276,6 +207,12 @@ def compute_pairwise_scores_with_loaders( gradient_scale = 1.0 / scaler.get_scale() set_gradient_scale(model=model, gradient_scale=gradient_scale) + dot_product_func = ( + compute_aggregated_dot_products_with_loader + if score_args.aggregate_train_gradients + else compute_dot_products_with_loader + ) + with tqdm( total=num_batches, desc="Computing pairwise scores (query gradient)", @@ -289,31 +226,36 @@ def compute_pairwise_scores_with_loaders( device=state.device, ) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): measurement = task.compute_measurement(batch=query_batch, model=model) scaler.scale(measurement).backward() - if factor_args.shared_parameters_exist: - finalize_preconditioned_gradient(model=model) + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) if state.use_distributed: - # Stacks preconditioned query gradient across multiple devices or nodes. - synchronize_preconditioned_gradient(model=model, num_processes=state.num_processes) + # Stack preconditioned query gradient across multiple devices or nodes. + synchronize_modules( + model=model, tracked_module_names=tracked_module_names, num_processes=state.num_processes + ) if query_index == len(query_loader) - 1 and query_remainder > 0: - # Removes duplicate data points if the dataset is not exactly divisible - # by the current batch size. - truncate_preconditioned_gradient(model=model, keep_size=query_remainder) + # Removes duplicate data points if the dataset is not evenly divisible by the current batch size. + truncate(model=model, tracked_module_names=tracked_module_names, keep_size=query_remainder) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + del query_batch, measurement - accumulate_preconditioned_gradient(model=model) num_accumulations += 1 - if num_accumulations < score_args.num_query_gradient_accumulations and query_index != len(query_loader) - 1: + if ( + num_accumulations < score_args.query_gradient_accumulation_steps + and query_index != len(query_loader) - 1 + ): pbar.update(1) continue # Computes the dot product between preconditioning query gradient and all training gradients. 
- scores = _compute_dot_products_with_loader( + scores = dot_product_func( model=model, state=state, task=task, @@ -325,27 +267,125 @@ def compute_pairwise_scores_with_loaders( disable_tqdm=disable_tqdm, ) - with torch.no_grad(): - if state.is_main_process: - for module_name, current_scores in scores.items(): - if module_name not in total_scores_chunks: - total_scores_chunks[module_name] = [] - total_scores_chunks[module_name].append(current_scores) - state.wait_for_everyone() + if state.is_main_process: + for module_name, current_scores in scores.items(): + if module_name not in total_scores_chunks: + total_scores_chunks[module_name] = [] + total_scores_chunks[module_name].append(current_scores) + del scores + state.wait_for_everyone() num_accumulations = 0 pbar.update(1) - with torch.no_grad(): - if state.is_main_process: - for module_name in total_scores_chunks: - total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) - state.wait_for_everyone() + if state.is_main_process: + for module_name in total_scores_chunks: + total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) - # Clean up the memory. - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return total_scores_chunks + + +def compute_pairwise_query_aggregated_scores_with_loaders( + loaded_factors: FACTOR_TYPE, + model: nn.Module, + state: State, + task: Task, + query_loader: data.DataLoader, + per_device_query_batch_size: int, + train_loader: data.DataLoader, + score_args: ScoreArguments, + factor_args: FactorArguments, + tracked_module_names: Optional[List[str]], + disable_tqdm: bool = False, +) -> Dict[str, torch.Tensor]: + """Computes pairwise influence scores (with query gradients aggregated) for a given model and task. 
See + `compute_pairwise_scores_with_loaders` for detailed information.""" + del per_device_query_batch_size + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + mode=ModuleMode.GRADIENT_AGGREGATION, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) + + enable_amp = score_args.amp_dtype is not None + scaler = GradScaler(enabled=enable_amp) + if enable_amp: + gradient_scale = 1.0 / scaler.get_scale() + set_gradient_scale(model=model, gradient_scale=gradient_scale) + + dot_product_func = ( + compute_aggregated_dot_products_with_loader + if score_args.aggregate_train_gradients + else compute_dot_products_with_loader + ) + + with tqdm( + total=len(query_loader), + desc="Computing pairwise scores (query gradient)", + bar_format=TQDM_BAR_FORMAT, + disable=not state.is_main_process or disable_tqdm, + ) as pbar: + for query_batch in query_loader: + query_batch = send_to_device( + tensor=query_batch, + device=state.device, + ) + + with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) + with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): + measurement = task.compute_measurement(batch=query_batch, model=model) + scaler.scale(measurement).backward() + + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + del measurement + pbar.update(1) + + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + + scores = dot_product_func( + model=model, + state=state, + task=task, + train_loader=train_loader, + factor_args=factor_args, + score_args=score_args, + tracked_module_names=tracked_module_names, + scaler=scaler, + disable_tqdm=disable_tqdm, + ) + + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() + + return scores diff --git a/kronfluence/score/self.py b/kronfluence/score/self.py index cbf8139..ff97d03 100644 --- a/kronfluence/score/self.py +++ b/kronfluence/score/self.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional import torch +import torch.distributed as dist from accelerate.utils import send_to_device from safetensors.torch import load_file, save_file from torch import autocast, nn @@ -13,11 +14,10 @@ from kronfluence.module import TrackedModule from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( - finalize_preconditioned_gradient, - finalize_self_measurement_scores, - finalize_self_scores, + accumulate_iterations, + finalize_iteration, get_tracked_module_names, - release_scores, + prepare_modules, set_factors, set_gradient_scale, set_mode, @@ -27,6 +27,7 @@ from kronfluence.task import Task from kronfluence.utils.constants import ( ALL_MODULE_NAME, + 
DISTRIBUTED_SYNC_INTERVAL, FACTOR_TYPE, PARTITION_TYPE, SCORE_TYPE, @@ -40,7 +41,18 @@ def self_scores_save_path( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading self-influence scores.""" + """Generates the path for saving or loading self-influence scores. + + Args: + output_dir (Path): + Directory to save or load the matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the score file. + """ if partition is not None: data_partition, module_partition = partition return output_dir / ( @@ -49,42 +61,75 @@ def self_scores_save_path( return output_dir / "self_scores.safetensors" -def self_scores_exist( +def save_self_scores( output_dir: Path, + scores: SCORE_TYPE, partition: Optional[PARTITION_TYPE] = None, -) -> bool: - """Checks if the self-influence scores exist at the specified path.""" + metadata: Optional[Dict[str, str]] = None, +) -> None: + """Saves self-influence scores to disk. + + Args: + output_dir (Path): + Directory to save the scores. + scores (SCORE_TYPE): + Dictionary of scores to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the scores. + """ save_path = self_scores_save_path( output_dir=output_dir, partition=partition, ) - return save_path.exists() + save_file(tensors=scores, filename=save_path, metadata=metadata) -def save_self_scores( +def load_self_scores( output_dir: Path, - scores: SCORE_TYPE, partition: Optional[PARTITION_TYPE] = None, - metadata: Optional[Dict[str, str]] = None, -) -> None: - """Saves self-influence scores to disk.""" +) -> SCORE_TYPE: + """Loads self-influence scores from disk. + + Args: + output_dir (Path): + Directory to load the scores from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + SCORE_TYPE: + Dictionary of loaded scores. + """ save_path = self_scores_save_path( output_dir=output_dir, partition=partition, ) - save_file(tensors=scores, filename=save_path, metadata=metadata) + return load_file(filename=save_path) -def load_self_scores( +def self_scores_exist( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, -) -> Dict[str, torch.Tensor]: - """Loads self-influence scores from disk.""" +) -> bool: + """Checks if self-influence scores exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for scores. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + bool: + `True` if scores exist, `False` otherwise. + """ save_path = self_scores_save_path( output_dir=output_dir, partition=partition, ) - return load_file(filename=save_path) + return save_path.exists() def compute_self_scores_with_loaders( @@ -102,7 +147,7 @@ def compute_self_scores_with_loaders( Args: loaded_factors (FACTOR_TYPE): - The factor results to load from, before computing the self-influence scores. + Computed factors. model (nn.Module): The model for which self-influence scores will be computed. state (State): @@ -112,48 +157,48 @@ def compute_self_scores_with_loaders( train_loader (data.DataLoader): The data loader that will be used to compute training gradients. score_args (ScoreArguments): - Arguments related to computing self-influence scores. + Arguments for computing self-influence scores. factor_args (FactorArguments): - Arguments related to computing preconditioning factors. + Arguments used to compute factors. 
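As the removed comment in the old code notes, these loops periodically call `state.wait_for_everyone()` (every `DISTRIBUTED_SYNC_INTERVAL` steps in the new code) so that ranks do not drift far enough apart to time out at the final synchronization; the self-score loop below also skips the barrier on the last two batches, presumably so it cannot collide with the end-of-loop synchronization. A sketch of that cadence with the barrier replaced by a print (the interval value is illustrative, not the library's constant):

```python
# Sketch of the periodic-synchronization cadence; `print` stands in for
# `state.wait_for_everyone()` so the example runs without a process group.
DISTRIBUTED_SYNC_INTERVAL = 3  # illustrative value only
num_batches = 10

for step in range(num_batches):
    # ... per-batch gradient and score computation would happen here ...
    at_tail = step in (num_batches - 1, num_batches - 2)
    if step % DISTRIBUTED_SYNC_INTERVAL == 0 and not at_tail:
        print(f"step {step}: barrier")
```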
tracked_module_names (List[str], optional): A list of module names that self-influence scores will be computed. If not specified, scores will be computed for all available tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: - Dict[str, torch.Tensor]: + SCORE_TYPE: A dictionary containing the module name and its self-influence scores. """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - mode=ModuleMode.SELF_SCORE, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - # Loads necessary factors before computing self-influence scores. - if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + mode=ModuleMode.SELF_SCORE, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) dataset_size = len(train_loader.dataset) score_chunks: Dict[str, List[torch.Tensor]] = {} - if score_args.per_module_score: + if score_args.compute_per_module_scores: for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: score_chunks[module.name] = [] else: score_chunks[ALL_MODULE_NAME] = [] + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + total_steps = 0 enable_amp = score_args.amp_dtype is not None scaler = GradScaler(enabled=enable_amp) @@ -167,14 +212,14 @@ def compute_self_scores_with_loaders( bar_format=TQDM_BAR_FORMAT, disable=not state.is_main_process or disable_tqdm, ) as pbar: - for batch in train_loader: + for index, batch in enumerate(train_loader): batch = send_to_device( tensor=batch, device=state.device, ) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): loss = task.compute_train_loss( batch=batch, @@ -183,59 +228,63 @@ def compute_self_scores_with_loaders( ) scaler.scale(loss).backward() - if factor_args.shared_parameters_exist: - finalize_self_scores(model=model) + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) with torch.no_grad(): - if score_args.per_module_score: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).cpu() - ) + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).to(device="cpu", copy=True) + ) else: - # 
Aggregates the self-influence scores across all modules. self_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if self_scores is None: - self_scores = module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone() - else: - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) - score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) - release_scores(model=model) - - if state.use_distributed and total_steps % score_args.distributed_sync_steps == 0: - # Periodically synchronizes all processes to avoid timeout at the final synchronization. + for module in cached_module_lst: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + ) + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + self_scores = self_scores.cpu() + score_chunks[ALL_MODULE_NAME].append(self_scores) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if ( + state.use_distributed + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 + and index not in [len(train_loader) - 1, len(train_loader) - 2] + ): state.wait_for_everyone() + del loss total_steps += 1 pbar.update(1) - with torch.no_grad(): - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode( - model=model, - mode=ModuleMode.DEFAULT, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - release_memory() - - total_scores: SCORE_TYPE = {} - for module_name, chunks in score_chunks.items(): - total_scores[module_name] = torch.cat(chunks, dim=0) - if state.use_distributed: - total_scores[module_name] = total_scores[module_name].to(device=state.device) - gather_list = None - if state.is_main_process: - gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) - if state.is_main_process: - total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() - state.wait_for_everyone() + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode( + model=model, + mode=ModuleMode.DEFAULT, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + release_memory() + + total_scores: SCORE_TYPE = {} + for module_name, chunks in score_chunks.items(): + total_scores[module_name] = torch.cat(chunks, dim=0) + if state.use_distributed: + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + dist.gather(total_scores[module_name], gather_list) + if state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() + else: + total_scores[module_name] = total_scores[module_name].cpu() + state.wait_for_everyone() return total_scores @@ -251,53 +300,32 @@ def compute_self_measurement_scores_with_loaders( tracked_module_names: Optional[List[str]], disable_tqdm: bool = False, ) -> Dict[str, torch.Tensor]: - """Computes self-influence scores with measurement (instead of the loss) for a given model and task. - - Args: - loaded_factors (FACTOR_TYPE): - The factor results to load from, before computing the self-influence scores. 
- model (nn.Module): - The model for which self-influence scores will be computed. - state (State): - The current process's information (e.g., device being used). - task (Task): - The specific task associated with the model. - train_loader (data.DataLoader): - The data loader that will be used to compute training gradients. - score_args (ScoreArguments): - Arguments related to computing self-influence scores. - factor_args (FactorArguments): - Arguments related to computing preconditioning factors. - tracked_module_names (List[str], optional): - A list of module names that self-influence scores will be computed. If not specified, scores - will be computed for all available tracked modules. - disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + """Computes self-influence scores with measurement (instead of the loss) for a given model and task. See + `compute_self_scores_with_loaders` for the detailed docstring.""" + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors( + model=model, + factor_name=name, + factors=loaded_factors[name], + clone=True, + ) + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) - Returns: - Dict[str, torch.Tensor]: - A dictionary containing the module name and its self-influence scores. - """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - # Loads necessary factors before computing self-influence scores. 
- if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) dataset_size = len(train_loader.dataset) score_chunks: Dict[str, List[torch.Tensor]] = {} - if score_args.per_module_score: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name] = [] + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name] = [] else: score_chunks[ALL_MODULE_NAME] = [] @@ -314,7 +342,7 @@ def compute_self_measurement_scores_with_loaders( bar_format=TQDM_BAR_FORMAT, disable=not state.is_main_process or disable_tqdm, ) as pbar: - for batch in train_loader: + for index, batch in enumerate(train_loader): batch = send_to_device( tensor=batch, device=state.device, @@ -324,25 +352,26 @@ def compute_self_measurement_scores_with_loaders( model=model, mode=ModuleMode.PRECONDITION_GRADIENT, tracked_module_names=tracked_module_names, - keep_factors=True, + release_memory=False, ) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): measurement = task.compute_measurement(batch=batch, model=model) scaler.scale(measurement).backward() - if factor_args.shared_parameters_exist: - finalize_preconditioned_gradient(model=model) + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + del measurement set_mode( model=model, mode=ModuleMode.SELF_MEASUREMENT_SCORE, tracked_module_names=tracked_module_names, - keep_factors=True, + release_memory=False, ) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): loss = task.compute_train_loss( batch=batch, @@ -351,58 +380,62 @@ def compute_self_measurement_scores_with_loaders( ) scaler.scale(loss).backward() - if factor_args.shared_parameters_exist: - finalize_self_measurement_scores(model=model) + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + del loss with torch.no_grad(): - if score_args.per_module_score: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).cpu() - ) + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).to(device="cpu", copy=True) + ) else: - # Aggregates the self-influence scores across all modules. 
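For context on the replacement hunk that follows: the distributed branch collects each rank's score chunk onto the main process with the standard gather-to-rank-zero pattern. A stripped-down sketch (not the library code; it assumes an already-initialized process group, and `dataset_size` truncates the padding added by the stacked sampler):

```python
# Illustrative sketch of gathering per-rank score tensors onto rank zero.
import torch
import torch.distributed as dist


def gather_scores(
    local_scores: torch.Tensor,
    num_processes: int,
    is_main_process: bool,
    dataset_size: int,
) -> torch.Tensor:
    gather_list = None
    if is_main_process:
        # Only the destination rank allocates receive buffers.
        gather_list = [torch.zeros_like(local_scores) for _ in range(num_processes)]
    dist.gather(local_scores, gather_list)  # Non-main ranks pass gather_list=None.
    if is_main_process:
        # Concatenate ranks in order and drop sampler padding.
        return torch.cat(gather_list, dim=0)[:dataset_size].cpu()
    return local_scores.cpu()
```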
self_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if self_scores is None: - self_scores = module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone() - else: - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) - score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) - release_scores(model=model) - - if state.use_distributed and total_steps % score_args.distributed_sync_steps == 0: - # Periodically synchronizes all processes to avoid timeout at the final synchronization. + for module in cached_module_lst: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + ) + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + self_scores = self_scores.cpu() + score_chunks[ALL_MODULE_NAME].append(self_scores) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if ( + state.use_distributed + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 + and index not in [len(train_loader) - 1, len(train_loader) - 2] + ): state.wait_for_everyone() total_steps += 1 pbar.update(1) - with torch.no_grad(): - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode( - model=model, - mode=ModuleMode.DEFAULT, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - release_memory() - - total_scores: SCORE_TYPE = {} - for module_name, chunks in score_chunks.items(): - total_scores[module_name] = torch.cat(chunks, dim=0) - if state.use_distributed: - total_scores[module_name] = total_scores[module_name].to(device=state.device) - gather_list = None - if state.is_main_process: - gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) - if state.is_main_process: - total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() - state.wait_for_everyone() + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode( + model=model, + mode=ModuleMode.DEFAULT, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + release_memory() + + total_scores: SCORE_TYPE = {} + for module_name, chunks in score_chunks.items(): + total_scores[module_name] = torch.cat(chunks, dim=0) + if state.use_distributed: + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + dist.gather(total_scores[module_name], gather_list) + if state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() + else: + total_scores[module_name] = total_scores[module_name].cpu() + state.wait_for_everyone() return total_scores diff --git a/kronfluence/task.py b/kronfluence/task.py index 55c22a0..1c82dd5 100644 --- a/kronfluence/task.py +++ b/kronfluence/task.py @@ -9,9 +9,15 @@ class Task(ABC): """Abstract base class for task definitions. Extend this class to implement specific tasks (e.g., regression, classification, language modeling) - with custom pipelines (models, data loaders, training objectives). + with custom pipelines (e.g., models, data loaders, training objectives). 
+ + Attributes: + enable_post_process_per_sample_gradient (bool): + Flag to enable post-processing of per-sample gradients. Defaults to `False`. """ + enable_post_process_per_sample_gradient: bool = False + @abstractmethod def compute_train_loss( self, @@ -19,23 +25,23 @@ def compute_train_loss( model: nn.Module, sample: bool = False, ) -> torch.Tensor: - """Computes training loss for a given batch and model. + """Computes the training loss for a given batch and model. Args: batch (Any): - Batch of data sourced from the DataLoader. + A batch of data from the DataLoader. model (nn.Module): - The PyTorch model for loss computation. + The PyTorch model used for loss computation. sample (bool): Indicates whether to sample from the model's outputs or to use the actual targets from the - batch. Defaults to False. The case where `sample` is set to True must be implemented to + batch. Defaults to `False`. The case where `sample=True` must be implemented to approximate the true Fisher. Returns: torch.Tensor: - The computed loss as a tensor. + The computed loss as a scalar tensor. """ - raise NotImplementedError("Subclasses must implement the `compute_train_loss` method.") + raise NotImplementedError(f"{self.__class__.__name__} must implement the `compute_train_loss` method.") @abstractmethod def compute_measurement( @@ -43,49 +49,68 @@ def compute_measurement( batch: Any, model: nn.Module, ) -> torch.Tensor: - """Computes a measurable quantity (e.g., loss, logit, log probability) for a given batch and model. - This is defined as f(θ) from https://arxiv.org/pdf/2308.03296.pdf. + """Computes a measurable quantity for a given batch and model. + + This method calculates f(θ) as defined in https://arxiv.org/pdf/2308.03296.pdf. The measurable quantity + can be a loss, logit, log probability, or any other relevant metric for the task. Args: batch (Any): - Batch of data sourced from the DataLoader. + A batch of data from the DataLoader. model (nn.Module): - The PyTorch model for measurement computation. + The PyTorch model used for measurement computation. Returns: torch.Tensor: - The measurable quantity as a tensor. + The computed measurable quantity as a tensor. """ - raise NotImplementedError("Subclasses must implement the `compute_measurement` method.") + raise NotImplementedError(f"{self.__class__.__name__} must implement the `compute_measurement` method.") - def tracked_modules(self) -> Optional[List[str]]: - """Specifies modules for influence score computations. + def get_influence_tracked_modules(self) -> Optional[List[str]]: + """Specifies which modules should be tracked for influence factor and score computations. - Returns None by default, applying computations to all supported modules (e.g., nn.Linear, nn.Conv2d). - Subclasses can override this method to return a list of specific module names if influence functions + Override this method in subclasses to return a list of specific module names if influence functions should only be computed for a subset of the model. Returns: Optional[List[str]]: - A list of module names for which to compute influence functions, or None to indicate that - influence functions should be computed for all applicable modules. + A list of module names to compute influence functions for, or `None` to compute for + all applicable modules (e.g., `nn.Linear` and `nn.Conv2d`). 
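A minimal illustration of overriding `get_influence_tracked_modules` in a hypothetical subclass (the module names are made up, and the abstract loss/measurement methods are omitted for brevity):

```python
from typing import List, Optional

from kronfluence.task import Task


class LastLayersTask(Task):  # Hypothetical; compute_train_loss/compute_measurement omitted.
    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        # Names as reported by `model.named_modules()`; adjust for your model.
        return ["model.fc2", "model.fc3"]
```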
""" def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: - """Returns masks for data points within a batch that have been padded extra tokens to ensure - consistent length across the batch. Typically, it returns None for models or datasets not requiring - masking. + """Gets attention masks for padded sequences in a batch. - See https://huggingface.co/docs/transformers/en/glossary#attention-mask. + This method is typically used for models or datasets that require masking, such as transformer-based + architectures. For more information, see: https://huggingface.co/docs/transformers/en/glossary#attention-mask. Args: batch (Any): - Batch of data sourced from the DataLoader. + A batch of data from the DataLoader. Returns: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: - A binary tensor as the mask for the batch, or None if padding is not used. The mask dimensions should - match `batch_size x num_seq`. For models requiring different masks for different modules - (e.g., encoder-decoder architectures), returns a dictionary mapping module names to their - corresponding masks. + - `None` if padding is not used. + - A binary tensor with dimension `batch_size x num_seq` as the mask for the batch. + - A dictionary mapping module names to their corresponding masks for models requiring different + masks for different modules (e.g., encoder-decoder architectures). + """ + + def post_process_per_sample_gradient(self, module_name: str, gradient: torch.Tensor) -> torch.Tensor: + """Post-processes the per-sample gradient of a specific module. + + This method is called only if `do_post_process_per_sample_gradient` is set to `True`. + Override this method in subclasses to implement custom gradient post-processing. + + Args: + module_name (str): + The name of the module whose gradient is being processed. + gradient (torch.Tensor): + The per-sample gradient tensor with dimension `batch_size x gradient_dim x activation_dim`. + + Returns: + torch.Tensor: + The modified per-sample gradient tensor. 
""" + del module_name + return gradient diff --git a/kronfluence/utils/common/factor_arguments.py b/kronfluence/utils/common/factor_arguments.py index 122baea..26861f5 100644 --- a/kronfluence/utils/common/factor_arguments.py +++ b/kronfluence/utils/common/factor_arguments.py @@ -4,26 +4,26 @@ def default_factor_arguments(strategy: str = "ekfac") -> FactorArguments: - """Default factor arguments.""" + """Creates default factor arguments""" factor_args = FactorArguments(strategy=strategy) return factor_args -def test_factor_arguments(strategy: str = "ekfac") -> FactorArguments: - """Factor arguments used for unit tests.""" +def pytest_factor_arguments(strategy: str = "ekfac") -> FactorArguments: + """Creates factor arguments for unit tests""" factor_args = FactorArguments(strategy=strategy) factor_args.use_empirical_fisher = True factor_args.activation_covariance_dtype = torch.float64 factor_args.gradient_covariance_dtype = torch.float64 factor_args.per_sample_gradient_dtype = torch.float64 - factor_args.lambda_dtype = torch.float32 + factor_args.lambda_dtype = torch.float64 return factor_args def smart_low_precision_factor_arguments( strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16 ) -> FactorArguments: - """Factor arguments with low precision, except for the lambda computations.""" + """Creates factor arguments with low precision, except for Lambda computations.""" factor_args = FactorArguments(strategy=strategy) factor_args.amp_dtype = dtype factor_args.activation_covariance_dtype = dtype @@ -34,7 +34,7 @@ def smart_low_precision_factor_arguments( def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: - """Factor arguments with low precision.""" + """Creates factor arguments with low precision for all computations.""" factor_args = FactorArguments(strategy=strategy) factor_args.amp_dtype = dtype factor_args.activation_covariance_dtype = dtype @@ -45,27 +45,18 @@ def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dty def reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: - """Factor arguments with low precision + iterative lambda update.""" + """Creates factor arguments with low precision and iterative lambda aggregations.""" factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype) - factor_args.lambda_iterative_aggregate = True + factor_args.use_iterative_lambda_aggregation = True return factor_args def extreme_reduce_memory_factor_arguments( - strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16 + strategy: str = "ekfac", module_partitions: int = 1, dtype: torch.dtype = torch.bfloat16 ) -> FactorArguments: - """Factor arguments for models that is difficult to fit in a single GPU.""" - factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype) - factor_args.lambda_iterative_aggregate = True - factor_args.cached_activation_cpu_offload = True - factor_args.covariance_module_partition_size = 4 - factor_args.lambda_module_partition_size = 4 - return factor_args - - -def large_dataset_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: - """Factor arguments for large models and datasets.""" - factor_args = smart_low_precision_factor_arguments(strategy=strategy, dtype=dtype) - factor_args.covariance_data_partition_size = 4 - factor_args.lambda_data_partition_size = 4 + """Creates factor arguments for models that are difficult to 
fit on a single GPU.""" + factor_args = reduce_memory_factor_arguments(strategy=strategy, dtype=dtype) + factor_args.offload_activations_to_cpu = True + factor_args.covariance_module_partitions = module_partitions + factor_args.lambda_module_partitions = module_partitions return factor_args diff --git a/kronfluence/utils/common/score_arguments.py b/kronfluence/utils/common/score_arguments.py index e674432..40675f6 100644 --- a/kronfluence/utils/common/score_arguments.py +++ b/kronfluence/utils/common/score_arguments.py @@ -6,91 +6,79 @@ def default_score_arguments( - damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None + damping_factor: Optional[float] = 1e-08, query_gradient_low_rank: Optional[int] = None ) -> ScoreArguments: - """Default score arguments.""" - score_args = ScoreArguments(damping=damping) - score_args.query_gradient_rank = query_gradient_rank - if score_args.query_gradient_rank is not None: - score_args.num_query_gradient_accumulations = 10 + """Creates default score arguments.""" + score_args = ScoreArguments(damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank) + if score_args.query_gradient_low_rank is not None: + score_args.query_gradient_accumulation_steps = 10 return score_args -def test_score_arguments(damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None) -> ScoreArguments: - """Score arguments used for unit tests.""" - score_args = ScoreArguments(damping=damping) +def pytest_score_arguments( + damping_factor: Optional[float] = 1e-08, query_gradient_low_rank: Optional[int] = None +) -> ScoreArguments: + """Creates score arguments for unit tests.""" + score_args = ScoreArguments(damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank) score_args.query_gradient_svd_dtype = torch.float64 score_args.score_dtype = torch.float64 score_args.per_sample_gradient_dtype = torch.float64 score_args.precondition_dtype = torch.float64 - score_args.query_gradient_rank = query_gradient_rank return score_args def smart_low_precision_score_arguments( - damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 + damping_factor: Optional[float] = 1e-08, + query_gradient_low_rank: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, ) -> ScoreArguments: - """Score arguments with low precision, except for the preconditioning computations.""" - score_args = ScoreArguments(damping=damping) + """Creates score arguments with low precision, except for preconditioning computations.""" + score_args = default_score_arguments(damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank) score_args.amp_dtype = dtype - score_args.query_gradient_svd_dtype = torch.float32 score_args.score_dtype = dtype score_args.per_sample_gradient_dtype = dtype + score_args.query_gradient_svd_dtype = torch.float32 score_args.precondition_dtype = torch.float32 - score_args.query_gradient_rank = query_gradient_rank - if score_args.query_gradient_rank is not None: - score_args.num_query_gradient_accumulations = 10 return score_args def all_low_precision_score_arguments( - damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 + damping_factor: Optional[float] = 1e-08, + query_gradient_low_rank: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, ) -> ScoreArguments: - """Score arguments with low precision.""" - score_args = ScoreArguments(damping=damping) + """Creates score 
arguments with low precision for all computations.""" + score_args = default_score_arguments(damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank) score_args.amp_dtype = dtype - score_args.query_gradient_svd_dtype = torch.float32 score_args.score_dtype = dtype score_args.per_sample_gradient_dtype = dtype score_args.precondition_dtype = dtype - score_args.query_gradient_rank = query_gradient_rank - if score_args.query_gradient_rank is not None: - score_args.num_query_gradient_accumulations = 10 + score_args.query_gradient_svd_dtype = torch.float32 return score_args def reduce_memory_score_arguments( - damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 + damping_factor: Optional[float] = 1e-08, + query_gradient_low_rank: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, ) -> ScoreArguments: - """Score arguments with low precision + CPU offload.""" - score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype) - score_args.cached_activation_cpu_offload = True - score_args.query_gradient_rank = query_gradient_rank - if score_args.query_gradient_rank is not None: - score_args.num_query_gradient_accumulations = 10 + """Creates score arguments with low precision and CPU offloading.""" + score_args = all_low_precision_score_arguments( + damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank, dtype=dtype + ) + score_args.offload_activations_to_cpu = True return score_args def extreme_reduce_memory_score_arguments( - damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 -) -> ScoreArguments: - """Score arguments for models that is difficult to fit in a single GPU.""" - score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype) - score_args.cached_activation_cpu_offload = True - score_args.query_gradient_rank = query_gradient_rank - score_args.module_partition_size = 4 - if score_args.query_gradient_rank is not None: - score_args.num_query_gradient_accumulations = 10 - return score_args - - -def large_dataset_score_arguments( - damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 + damping_factor: Optional[float] = 1e-08, + module_partitions: int = 4, + query_gradient_low_rank: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, ) -> ScoreArguments: - """Score arguments for large models and datasets.""" - score_args = smart_low_precision_score_arguments(damping=damping, dtype=dtype) - score_args.data_partition_size = 4 - score_args.query_gradient_rank = query_gradient_rank - if score_args.query_gradient_rank is not None: - score_args.num_query_gradient_accumulations = 10 + """Creates score arguments for models that are difficult to fit on a single GPU.""" + score_args = reduce_memory_score_arguments( + damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank, dtype=dtype + ) + score_args.module_partitions = module_partitions return score_args diff --git a/kronfluence/utils/constants.py b/kronfluence/utils/constants.py index 37c55f9..4a86ed3 100644 --- a/kronfluence/utils/constants.py +++ b/kronfluence/utils/constants.py @@ -1,12 +1,22 @@ """A collection of constants.""" -from typing import Dict, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch FACTOR_TYPE = Dict[str, Dict[str, torch.Tensor]] PARTITION_TYPE = Tuple[int, int] SCORE_TYPE = Dict[str, torch.Tensor] 
+PRECONDITIONED_GRADIENT_TYPE = Optional[Union[torch.Tensor, List[torch.Tensor]]] + +# Constants for file naming conventions. +FACTOR_SAVE_PREFIX = "factors_" +SCORE_SAVE_PREFIX = "scores_" +FACTOR_ARGUMENTS_NAME = "factor" +SCORE_ARGUMENTS_NAME = "score" + +# The interval, in iteration steps, at which to synchronize processes in the distributed setting. +DISTRIBUTED_SYNC_INTERVAL = 1_000 # Activation covariance matrix. ACTIVATION_COVARIANCE_MATRIX_NAME = "activation_covariance" @@ -53,11 +63,14 @@ # Preconditioned per-sample gradient. PRECONDITIONED_GRADIENT_NAME = "preconditioned_gradient" -ACCUMULATED_PRECONDITIONED_GRADIENT_NAME = "aggregated_preconditioned_gradient" +# Accumulated preconditioned per-sample gradient. +ACCUMULATED_PRECONDITIONED_GRADIENT_NAME = "accumulated_preconditioned_gradient" +# Aggregated gradient. +AGGREGATED_GRADIENT_NAME = "aggregated_gradient" # Pairwise influence scores. PAIRWISE_SCORE_MATRIX_NAME = "pairwise_score_matrix" # Self-influence scores. SELF_SCORE_VECTOR_NAME = "self_score_vector" -# The dictionary key for storing scores for all modules. +# The dictionary key for storing summed scores. ALL_MODULE_NAME = "all_modules" diff --git a/kronfluence/utils/dataset.py b/kronfluence/utils/dataset.py index 920cf55..9f33b22 100644 --- a/kronfluence/utils/dataset.py +++ b/kronfluence/utils/dataset.py @@ -16,8 +16,11 @@ @dataclass class DataLoaderKwargs(KwargsHandler): - """The object used to customize `DataLoader`. Please refer to https://pytorch.org/docs/stable/data.html for - detailed information of each argument. The default arguments are copied from PyTorch version 2.3. + """Customization options for DataLoader. + + This class encapsulates the arguments used to customize PyTorch's DataLoader. Default values are based on + PyTorch version 2.3. For detailed information on each argument, refer to: + https://pytorch.org/docs/stable/data.html. """ num_workers: int = 0 @@ -33,9 +36,23 @@ class DataLoaderKwargs(KwargsHandler): def make_indices_partition(total_data_examples: int, partition_size: int) -> List[Tuple[int, int]]: - """Returns partitioned indices from the total data examples.""" + """Partitions data indices into approximately equal-sized bins. + + Args: + total_data_examples (int): + Total number of data examples. + partition_size (int): + Number of partitions to create. + + Returns: + List[Tuple[int, int]]: + List of tuples, each containing start and end indices for a partition. + + Raises: + ValueError: If `total_data_examples` is less than `partition_size`. + """ if total_data_examples < partition_size: - raise ValueError("The total data examples must be equal or greater than the partition size.") + raise ValueError("The total data examples must be equal to or greater than the partition size.") # See https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length. bins = list(map(len, np.array_split(range(total_data_examples), partition_size))) start_idx = 0 @@ -47,8 +64,26 @@ def make_indices_partition(total_data_examples: int, partition_size: int) -> Lis def find_executable_batch_size(func: Callable, start_batch_size: int) -> int: - """Finds executable batch size for calling the function that does not encounter OOM error. The code is motivated - from https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/memory.py#L83.""" + """Finds the largest batch size that can be executed without OOM errors. 
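To make the partitioning above concrete, a standalone re-implementation of the same splitting logic behaves as follows (illustrative only; the library's `make_indices_partition` additionally validates its inputs):

```python
from typing import List, Tuple

import numpy as np


def partition_indices(total_data_examples: int, partitions: int) -> List[Tuple[int, int]]:
    # Mirror of the helper above: near-equal bins expressed as (start, end) pairs.
    bins = list(map(len, np.array_split(range(total_data_examples), partitions)))
    indices, start_idx = [], 0
    for size in bins:
        indices.append((start_idx, start_idx + size))
        start_idx += size
    return indices


print(partition_indices(10, 3))  # [(0, 4), (4, 7), (7, 10)]
```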
+ + This function progressively reduces the batch size until it finds a size that can be executed + without running out of memory. The code is motivated from: + https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/memory.py#L83 + + Args: + func (Callable): + Function to test with different batch sizes. + start_batch_size (int): + Initial batch size to try. + + Returns: + int: + The largest executable batch size. + + Raises: + RuntimeError: + If no executable batch size is found (reaches zero). + """ batch_size = start_batch_size while True: @@ -67,16 +102,15 @@ def find_executable_batch_size(func: Callable, start_batch_size: int) -> int: class DistributedEvalSampler(Sampler[T_co]): - """DistributedEvalSampler is different from `DistributedSampler`: it does not add extra samples to make - the dataset evenly divisible. DistributedEvalSampler should not be used for training; the distributed processes - could hang forever. See this issue for details: https://github.com/pytorch/pytorch/issues/22584. + """Sampler for distributed setting without adding extra samples. + Unlike `DistributedSampler`, it does not add extra samples to make the dataset evenly divisible across processes. The code is adapted from https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py. """ def __init__( # pylint: disable=super-init-not-called self, - dataset: torch.utils.data.Dataset, + dataset: data.Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, @@ -112,10 +146,10 @@ def __len__(self) -> int: class DistributedSamplerWithStack(Sampler[T_co]): - """DistributedSampleWithStack is different from `DistributedSampler`. Instead of subsampling, - it stacks the dataset. For example, when `num_replicas` is 3, and the dataset of [0, ..., 9] is given, - the first, second, and third rank should have [0, 1, 2], [3, 4, 5], and [6, 7, 8], respectively. However, - it still adds extra samples to make the dataset evenly divisible (different from DistributedEvalSampler). + """Sampler that stacks the dataset for distributed setting. + + Instead of subsampling, this sampler stacks the dataset across processes. It ensures even distribution by + adding padding samples if necessary. """ def __init__( # pylint: disable=super-init-not-called diff --git a/kronfluence/utils/exceptions.py b/kronfluence/utils/exceptions.py index 597ca5b..e64f334 100644 --- a/kronfluence/utils/exceptions.py +++ b/kronfluence/utils/exceptions.py @@ -3,12 +3,12 @@ class FactorsNotFoundError(ValueError): class TrackedModuleNotFoundError(ValueError): - """Exception raised when the tracked module is not found.""" + """Exception raised when a tracked module is not found in the model.""" class IllegalTaskConfigurationError(ValueError): - """Exception raised when the provided task is determined to be invalid.""" + """Exception raised when the provided task configuration is determined to be invalid.""" class UnsupportableModuleError(NotImplementedError): - """Exception raised when the provided module is not supported.""" + """Exception raised when the provided module is not supported by the current implementation.""" diff --git a/kronfluence/utils/logger.py b/kronfluence/utils/logger.py index 04af511..2ca1a0f 100644 --- a/kronfluence/utils/logger.py +++ b/kronfluence/utils/logger.py @@ -20,21 +20,33 @@ class MultiProcessAdapter(logging.LoggerAdapter): - """An adapter to assist with logging in multiprocess. + """An adapter for logging in multiprocess environments. 
- The code is copied from https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py with - minor modifications. + The code is adapted from: https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py. """ def log(self, level: int, msg: str, *args, **kwargs) -> None: - """Delegates logger call after checking if it should log.""" + """Log a message if logging is enabled for this process.""" if self.isEnabledFor(level) and not self.extra["disable_log"]: msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) def get_logger(name: str, disable_log: bool = False, log_level: int = None) -> MultiProcessAdapter: - """Returns the logger with an option to disable logging.""" + """Creates and returns a logger with optional disabling and log level setting. + + Args: + name (str): + Name of the logger. + disable_log (bool): + Whether to disable logging. Defaults to `False`. + log_level (int): + Logging level to set. Defaults to `None`. + + Returns: + MultiProcessAdapter: + Configured logger adapter. + """ logger = logging.getLogger(name) if log_level is not None: logger.setLevel(log_level) @@ -43,16 +55,15 @@ def get_logger(name: str, disable_log: bool = False, log_level: int = None) -> M class Profiler: - """Profiling object to measure the time taken to run a certain operation. The profiler is helpful - for checking any bottlenecks in the code. + """A profiling utility to measure execution time of operations. - The code is modified from: + The code is adapted from: - https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers. - https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/profiler.py. """ def __init__(self, state: State) -> None: - """Initializes an instance of the Profiler class. + """Initializes an instance of the `Profiler` class. 
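A hypothetical usage of the `Profiler` above, assuming the default `State` is usable in a single-process run:

```python
from kronfluence.utils.logger import Profiler
from kronfluence.utils.state import State

profiler = Profiler(state=State())
with profiler.profile("fit_covariance"):
    _ = sum(i * i for i in range(100_000))  # Stand-in for real work.
print(profiler.summary())  # Tabulated mean/total durations per action.
```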
Args: state (State): @@ -63,7 +74,7 @@ def __init__(self, state: State) -> None: self.recorded_durations = defaultdict(list) def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" + """Start recording an action.""" if not self.state.is_main_process: return if action_name in self.current_actions: @@ -71,7 +82,7 @@ def start(self, action_name: str) -> None: self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: - """Defines how to record the duration once an action is complete.""" + """Stop recording an action and log its duration.""" if not self.state.is_main_process: return end_time = _get_monotonic_time() @@ -83,7 +94,7 @@ def stop(self, action_name: str) -> None: @contextmanager def profile(self, action_name: str) -> Generator: - """Yields a context manager to encapsulate the scope of a profiled action.""" + """Context manager for profiling an action.""" try: self.start(action_name) yield action_name @@ -92,6 +103,7 @@ def profile(self, action_name: str) -> Generator: @torch.no_grad() def _make_report(self) -> Tuple[_TABLE_DATA, float, float]: + """Generate a report of profiled actions.""" total_duration = 0.0 for a, d in self.recorded_durations.items(): d_tensor = torch.tensor(d, dtype=torch.float64, requires_grad=False) @@ -110,7 +122,7 @@ def _make_report(self) -> Tuple[_TABLE_DATA, float, float]: return report, total_calls, total_duration def summary(self) -> str: - """Returns a formatted summary for the Profiler.""" + """Generate a formatted summary of the profiling results.""" sep = os.linesep output_string = "Profiler Report:" @@ -143,37 +155,34 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str class PassThroughProfiler(Profiler): - """A pass through Profiler objective that does not record timing.""" + """A no-op profiler that doesn't record any timing information.""" def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" return def stop(self, action_name: str) -> None: - """Defines how to record the duration once an action is complete.""" return def summary(self) -> str: - """Returns a formatted summary for the Profiler.""" return "" class TorchProfiler(Profiler): - """A PyTorch Profiler objective that provides detailed profiling information: - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html. + """A profiler that utilizes PyTorch's built-in profiling capabilities. + + This profiler provides detailed information about PyTorch operations, including CPU and CUDA events. + It's useful for low-level profiling in PyTorch. - This is useful for low-level profiling in PyTorch, and is not used by default. + Note: This is not used by default and is intended for detailed performance analysis. """ def __init__(self, state: State) -> None: - """Initializes an instance of the PyTorch Profiler class.""" super().__init__(state=state) self.actions: list = [] self.trace_outputs: list = [] self._set_up_torch_profiler() def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" if action_name in self.current_actions: raise ValueError(f"Attempted to start {action_name} which has already started.") # Set dummy value, since only used to track duplicate actions. 
@@ -182,14 +191,12 @@ def start(self, action_name: str) -> None: self._torch_prof.start() def stop(self, action_name: str) -> None: - """Defines how to stop recording an action.""" if action_name not in self.current_actions: raise ValueError(f"Attempting to stop recording an action " f"({action_name}) which was never started.") _ = self.current_actions.pop(action_name) self._torch_prof.stop() def _set_up_torch_profiler(self) -> None: - """Creates the PyTorch profiler object with the necessary arguments.""" self._torch_prof = t_prof.profile( activities=[t_prof.ProfilerActivity.CPU, t_prof.ProfilerActivity.CUDA], record_shapes=True, @@ -200,7 +207,6 @@ def _set_up_torch_profiler(self) -> None: ) def _trace_handler(self, p) -> None: - """Adds the PyTorch Profiler trace output to a list once it is ready.""" # Set metric to sort based on device. is_cpu = self.state.device == torch.device("cpu") sort_by_metric = "self_cpu_time_total" if is_cpu else "self_cuda_time_total" @@ -218,12 +224,10 @@ def _trace_handler(self, p) -> None: self.recorded_durations[self.actions[-1]].append(total_time) def _reset_output(self) -> None: - """Resets actions and outputs list.""" self.actions = [] self.trace_outputs = [] def _high_level_summary(self) -> str: - """Returns a formatted high level summary for the PyTorch Profiler.""" sep = os.linesep output_string = "Overall PyTorch Profiler Report:" @@ -255,7 +259,6 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str return output_string def summary(self) -> str: - """Returns a formatted summary for the PyTorch Profiler.""" assert len(self.actions) == len(self.trace_outputs), ( "Mismatch in the number of actions and outputs collected: " + f"# Actions: {len(self.actions)}, # Ouptuts: {len(self.trace_outputs)}" @@ -272,10 +275,15 @@ def summary(self) -> str: return summary -# Timing utilities copied from +# Timing utilities copied from: # https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/pytorch_utils.py. def _get_monotonic_time() -> float: - """Gets the monotonic time after the CUDA synchronization if necessary.""" + """Gets the time after the CUDA synchronization. + + Returns: + float: + The current time. + """ if torch.cuda.is_available() and torch.cuda.is_initialized(): torch.cuda.synchronize() return time.monotonic() @@ -283,7 +291,16 @@ def _get_monotonic_time() -> float: @torch.no_grad() def get_time(state: State) -> float: - """Gets the current time after synchronizing with other devices.""" + """Gets the current time after synchronizing with other devices. + + Args: + state (State): + The current process's information (e.g., device being used). + + Returns: + float: + The current time. + """ if not state.use_distributed: if torch.cuda.is_available() and torch.cuda.is_initialized(): torch.cuda.synchronize() diff --git a/kronfluence/utils/model.py b/kronfluence/utils/model.py index 1a71a79..e58bd33 100644 --- a/kronfluence/utils/model.py +++ b/kronfluence/utils/model.py @@ -20,21 +20,25 @@ def apply_ddp( rank: int, world_size: int, ) -> DistributedDataParallel: - """Applies DistributedDataParallel (DDP) to the given model. + """Applies DistributedDataParallel (DDP) to the given PyTorch model. Args: model (nn.Module): - The model for which DDP will be applied. + The PyTorch model to be parallelized. local_rank (int): - The local rank of the current process. + The local rank of the current process within its node. rank (int): - The rank of the current process. 
+ The global rank of the current process across all nodes. world_size (int): - The total number of processes. + The total number of processes in the distributed setup. Returns: DistributedDataParallel: - The model wrapped with DDP. + The input model wrapped with DDP. + + Raises: + RuntimeError: + If the distributed initialization fails. """ dist.init_process_group("nccl", rank=rank, world_size=world_size) device = torch.device(f"cuda:{local_rank}") @@ -61,31 +65,35 @@ def apply_fsdp( is_transformer: bool = False, layer_to_wrap: Optional[nn.Module] = None, ) -> FSDP: - """Applies FullyShardedDataParallel (FSDP) to the given model. + """Applies FullyShardedDataParallel (FSDP) to the given PyTorch model. Args: model (nn.Module): - The model for which FSDP will be applied. + The PyTorch model to be parallelized. local_rank (int): - The local rank of the current process. + The local rank of the current process within its node. rank (int): - The rank of the current process. + The global rank of the current process across all nodes. world_size (int): - The total number of processes. + The total number of processes in the distributed setup. sharding_strategy (str): - The sharding strategy to use. Defaults to "FULL_SHARD". + The FSDP sharding strategy to use. Defaults to "FULL_SHARD". cpu_offload (bool): - Whether to offload parameters to CPU. Check - https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.CPUOffload. Defaults to True. + Whether to offload parameters to CPU. Defaults to `True`. is_transformer (bool): - Whether the model is a transformer model. Defaults to False. + Whether the model is a transformer. Defaults to `False`. layer_to_wrap (nn.Module, optional): - The specific layer to wrap for transformer models. Required if `is_transformer` is True. - Defaults to None. + The specific layer to wrap for transformer models. Required if `is_transformer` is `True`. Returns: FullyShardedDataParallel: - The model wrapped with FSDP. + The input model wrapped with FSDP. + + Raises: + ValueError: + If an invalid sharding strategy is provided or if `layer_to_wrap` is not provided for transformer models. + RuntimeError: + If the distributed initialization fails. """ dist.init_process_group("nccl", rank=rank, world_size=world_size) device = torch.device(f"cuda:{local_rank}") diff --git a/kronfluence/utils/save.py b/kronfluence/utils/save.py index bb0566a..f129551 100644 --- a/kronfluence/utils/save.py +++ b/kronfluence/utils/save.py @@ -5,38 +5,87 @@ import torch from safetensors import safe_open -FACTOR_SAVE_PREFIX = "factors_" -SCORE_SAVE_PREFIX = "scores_" -FACTOR_ARGUMENTS_NAME = "factor" -SCORE_ARGUMENTS_NAME = "score" +def load_file(path: Path) -> Dict[str, torch.Tensor]: + """Loads a dictionary of tensors from a file using `safetensors`. + Args: + path (Path): + The path to the file containing tensor data. -def load_file(path: Path) -> Dict[str, torch.Tensor]: - """Loads a dictionary of tensors from the path.""" - load_dict = {} - with safe_open(path, framework="pt", device="cpu") as f: - for key in f.keys(): - load_dict[key] = f.get_tensor(name=key) - return load_dict + Returns: + Dict[str, torch.Tensor]: + A dictionary where keys are tensor names and values are the corresponding tensors. 
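A sketch of invoking `apply_ddp` from a `torchrun`-style launcher, which exports the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables (the model and process layout are illustrative, and the NCCL backend requires one GPU per process):

```python
import os

import torch.nn as nn

from kronfluence.utils.model import apply_ddp

model = nn.Linear(16, 1)
ddp_model = apply_ddp(
    model=model,
    local_rank=int(os.environ["LOCAL_RANK"]),
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)
```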
+ """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}.") + try: + with safe_open(path, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} + except Exception as e: + raise RuntimeError(f"Error loading file {path}: {str(e)}") from e def save_json(obj: Any, path: Path) -> None: - """Saves the object to a JSON file.""" - with open(path, "w", encoding="utf-8") as f: - json.dump(obj, f, indent=4) + """Saves an object to a JSON file. + + This function serializes the given object to JSON format and writes it to a file. + + Args: + obj (Any): + The object to be saved. Must be JSON-serializable. + path (Path): + The path where the JSON file will be saved. + """ + path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(obj, f, indent=4, ensure_ascii=False) + except TypeError as e: + raise TypeError(f"Object is not JSON-serializable: {str(e)}") from e + except Exception as e: + raise IOError(f"Error saving JSON file {path}: {str(e)}") from e def load_json(path: Path) -> Dict[str, Any]: - """Loads an object from the JSON file.""" - with open(path, "rb") as f: - obj = json.load(f) - return obj + """Loads an object from a JSON file. + + Args: + path (Path): + The path to the JSON file to be loaded. + + Returns: + Dict[str, Any]: + The object loaded from the JSON file. + """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}.") + with open(path, "r", encoding="utf-8") as f: + return json.load(f) -@torch.no_grad() def verify_models_equivalence(state_dict1: Dict[str, torch.Tensor], state_dict2: Dict[str, torch.Tensor]) -> bool: - """Checks if two models are equivalent given their `state_dict`.""" + """Check if two models are equivalent given their state dictionaries. + + This function compares two model state dictionaries to determine if they represent + equivalent models. It checks for equality in the number of parameters, parameter names, + and parameter values (within a small tolerance). + + Args: + state_dict1 (Dict[str, torch.Tensor]): + The state dictionary of the first model. + state_dict2 (Dict[str, torch.Tensor]): + The state dictionary of the second model. + + Returns: + bool: + `True` if the models are equivalent, `False` otherwise. + + Notes: + - The function uses a relative tolerance of 1.3e-6 and an absolute tolerance of 1e-5 + when comparing tensor values. + - Tensors are compared in float32 precision on the CPU to ensure consistency. 
+ """ if len(state_dict1) != len(state_dict2): return False @@ -46,7 +95,7 @@ def verify_models_equivalence(state_dict1: Dict[str, torch.Tensor], state_dict2: for name in state_dict1: tensor1 = state_dict1[name].to(dtype=torch.float32).cpu() tensor2 = state_dict2[name].to(dtype=torch.float32).cpu() - if not torch.allclose(tensor1, tensor2, rtol=1e-3, atol=1e-5): + if not torch.allclose(tensor1, tensor2, rtol=1.3e-6, atol=1e-5): return False return True diff --git a/kronfluence/utils/state.py b/kronfluence/utils/state.py index 08705dd..73d5d1e 100644 --- a/kronfluence/utils/state.py +++ b/kronfluence/utils/state.py @@ -1,13 +1,12 @@ import contextlib import gc import os -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List import torch import torch.distributed as dist from accelerate.state import SharedDict from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel class State: @@ -23,12 +22,11 @@ class State: _shared_state: Dict[str, Any] = SharedDict() def __init__(self, cpu: bool = False) -> None: - """Initializes an instance of the State class. + """Initializes an instance of the `State` class. Args: cpu (bool): - Specifies whether the analysis should be explicitly performed using the CPU. - Defaults to False, utilizing GPU resources if available. + If `True`, forces the use of CPU even if GPUs are available. Defaults to `False`. """ self.__dict__ = self._shared_state @@ -51,6 +49,12 @@ def __init__(self, cpu: bool = False) -> None: self.device = torch.device("cpu") if self.cpu else self.default_device def __repr__(self) -> str: + """Provides a string representation of the `State` instance. + + Returns: + str: + A formatted string containing process and device information. + """ return ( f"Num processes: {self.num_processes}\n" f"Process index: {self.process_index}\n" @@ -60,64 +64,101 @@ def __repr__(self) -> str: @staticmethod def _reset_state() -> None: - """Resets `_shared_state`, is used internally and should not be called.""" + """Resets the shared state. For internal use only.""" State._shared_state.clear() @property def initialized(self) -> bool: - """Returns whether the `PartialState` has been initialized.""" + """Checks if the `State` has been initialized.""" return self._shared_state != {} @property def use_distributed(self) -> bool: - """Whether the State is configured for distributed training.""" + """Checks if the setup is configured for distributed setting.""" return self.num_processes > 1 @property def is_main_process(self) -> bool: - """Returns whether the current process is the main process.""" + """Checks if the current process is the main process.""" return self.process_index == 0 @property def is_local_main_process(self) -> bool: - """Returns whether the current process is the main process on the local node.""" + """Checks if the current process is the main process on the local node.""" return self.local_process_index == 0 @property def is_last_process(self) -> bool: - """Returns whether the current process is the last one.""" + """Checks if the current process is the last one.""" return self.process_index == self.num_processes - 1 def wait_for_everyone(self) -> None: - """Will stop the execution of the current process until every other process has reached that point - (so this does nothing when the script is only run in one process).""" + """Synchronizes all processes. + + This method will pause the execution of the current process until all other processes + reach this point. 
It has no effect in single-process execution. + """ if self.use_distributed: dist.barrier() @property def default_device(self) -> torch.device: - """Finds the default device currently available.""" + """Determines the default device (CUDA if available, otherwise CPU). + + Returns: + torch.device: + The default device. + """ if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def release_memory() -> None: - """Releases the memory by calling `gc.collect()` and `torch.cuda.empty_cache()`.""" + """Releases unused memory. + + This function calls Python's garbage collector and empties CUDA cache if CUDA is available. + """ gc.collect() + torch.compiler.reset() if torch.cuda.is_available(): torch.cuda.empty_cache() +def get_active_tensors() -> List[tuple]: + """Gets a list of active tensors in memory. + + Returns: + List[tuple]: + A list of tuples containing tensor type and size. + """ + tensor_lst = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)): + tensor_lst.append((type(obj), obj.size())) + return tensor_lst + + @contextlib.contextmanager def no_sync(model: nn.Module, state: State) -> Callable: - """A context manager to avoid DDP synchronization. The code is adapted from - https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L852.""" + """A context manager to temporarily disable gradient synchronization in a distributed setting. + + Args: + model (nn.Module): + The PyTorch model. + state (State): + The current process state. + + Yields: + A context where gradient synchronization is disabled (if applicable). + + Note: + For FullyShardedDataParallel (FSDP) models, this may result in higher memory usage. + See: https://pytorch.org/docs/stable/fsdp.html. + """ context = contextlib.nullcontext - # `no_sync()` for FSDP instance can result in higher memory usage, detailed in: - # https://pytorch.org/docs/stable/fsdp.html. 
- if state.use_distributed and not isinstance(model, FullyShardedDataParallel): + if state.use_distributed: context = getattr(model, "no_sync", context) with context(): diff --git a/kronfluence/version.py b/kronfluence/version.py index 3dc1f76..5becc17 100644 --- a/kronfluence/version.py +++ b/kronfluence/version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "1.0.0" diff --git a/pyproject.toml b/pyproject.toml index 5f94bb7..870cc40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,4 +34,5 @@ disable = """ implicit-str-concat, inconsistent-return-statements, too-many-lines, + too-many-public-methods, """ \ No newline at end of file diff --git a/tests/factors/test_covariances.py b/tests/factors/test_covariances.py index 4ab3240..8603f0e 100644 --- a/tests/factors/test_covariances.py +++ b/tests/factors/test_covariances.py @@ -5,7 +5,7 @@ from kronfluence.utils.common.factor_arguments import ( default_factor_arguments, - test_factor_arguments, + pytest_factor_arguments, ) from kronfluence.utils.constants import ( ACTIVATION_COVARIANCE_MATRIX_NAME, @@ -17,8 +17,10 @@ from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ( ATOL, + DEFAULT_FACTORS_NAME, RTOL, check_tensor_dict_equivalence, + custom_factors_name, prepare_model_and_analyzer, prepare_test, ) @@ -29,15 +31,15 @@ [ "mlp", "repeated_mlp", - "mlp_checkpoint", "conv", - "conv_bn", "bert", + "roberta", "gpt", + "gpt_checkpoint", ], ) -@pytest.mark.parametrize("activation_covariance_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("gradient_covariance_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("activation_covariance_dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("gradient_covariance_dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("train_size", [16]) @pytest.mark.parametrize("seed", [0]) def test_fit_covariance_matrices( @@ -58,24 +60,23 @@ def test_fit_covariance_matrices( model=model, task=task, ) - factor_args = default_factor_arguments() factor_args.activation_covariance_dtype = activation_covariance_dtype factor_args.gradient_covariance_dtype = gradient_covariance_dtype - factors_name = f"pytest_{test_name}_{test_fit_covariance_matrices.__name__}" analyzer.fit_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, + factor_args=factor_args, dataset=train_dataset, dataloader_kwargs=kwargs, - factor_args=factor_args, per_device_batch_size=train_size // 4, overwrite_output_dir=True, ) covariance_factors = analyzer.load_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, ) assert set(covariance_factors.keys()) == set(COVARIANCE_FACTOR_NAMES) assert len(covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME]) > 0 + assert len(covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME]) > 0 for module_name in covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME]: assert covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME][module_name].dtype == activation_covariance_dtype assert covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME][module_name].dtype == gradient_covariance_dtype @@ -103,35 +104,32 @@ def test_covariance_matrices_batch_size_equivalence( seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( - 
factors_name=f"pytest_{test_name}_{test_covariance_matrices_batch_size_equivalence.__name__}_bs1", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=1, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - bs1_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_batch_size_equivalence.__name__}_bs1" - ) + bs1_covariance_factors = analyzer.load_covariance_matrices(factors_name=DEFAULT_FACTORS_NAME) analyzer.fit_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_batch_size_equivalence.__name__}_bs8", + factors_name=custom_factors_name(name="bs8"), dataset=train_dataset, factor_args=factor_args, per_device_batch_size=8, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - bs8_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_batch_size_equivalence.__name__}_bs8" - ) + bs8_covariance_factors = analyzer.load_covariance_matrices(factors_name=custom_factors_name(name="bs8")) for name in COVARIANCE_FACTOR_NAMES: assert check_tensor_dict_equivalence( @@ -146,17 +144,18 @@ def test_covariance_matrices_batch_size_equivalence( "test_name", [ "mlp", - "conv", + "conv_bn", + "bert", ], ) -@pytest.mark.parametrize("data_partition_size", [2, 4]) -@pytest.mark.parametrize("module_partition_size", [2, 3]) +@pytest.mark.parametrize("data_partitions", [2, 4]) +@pytest.mark.parametrize("module_partitions", [2, 3]) @pytest.mark.parametrize("train_size", [62]) @pytest.mark.parametrize("seed", [2]) def test_covariance_matrices_partition_equivalence( test_name: str, - data_partition_size: int, - module_partition_size: int, + data_partitions: int, + module_partitions: int, train_size: int, seed: int, ) -> None: @@ -167,27 +166,27 @@ def test_covariance_matrices_partition_equivalence( seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - factor_args = test_factor_arguments() - factors_name = f"pytest_{test_name}_{test_covariance_matrices_partition_equivalence.__name__}" + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=8, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - covariance_factors = analyzer.load_covariance_matrices(factors_name=factors_name) + covariance_factors = analyzer.load_covariance_matrices(factors_name=DEFAULT_FACTORS_NAME) - factor_args.covariance_data_partition_size = data_partition_size - factor_args.covariance_module_partition_size = module_partition_size + factor_args.covariance_data_partitions = data_partitions + factor_args.covariance_module_partitions = module_partitions analyzer.fit_covariance_matrices( - factors_name=f"pytest_{test_name}_partitioned_{data_partition_size}_{module_partition_size}", + factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"), dataset=train_dataset, factor_args=factor_args, per_device_batch_size=7, @@ -195,7 +194,7 @@ def test_covariance_matrices_partition_equivalence( dataloader_kwargs=kwargs, ) partitioned_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_partitioned_{data_partition_size}_{module_partition_size}", + 
factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"), ) for name in COVARIANCE_FACTOR_NAMES: @@ -207,7 +206,7 @@ def test_covariance_matrices_partition_equivalence( ) -@pytest.mark.parametrize("test_name", ["bert", "wrong_bert", "gpt"]) +@pytest.mark.parametrize("test_name", ["bert", "wrong_bert", "roberta"]) @pytest.mark.parametrize("train_size", [213]) @pytest.mark.parametrize("seed", [3]) def test_covariance_matrices_attention_mask( @@ -235,17 +234,15 @@ def test_covariance_matrices_attention_mask( seed=seed, ) model = model.to(dtype=torch.float64) - - kwargs = DataLoaderKwargs(collate_fn=data_collator) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) - factor_args = test_factor_arguments() - factors_name = f"pytest_{test_name}_{test_covariance_matrices_attention_mask.__name__}" + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=train_size // 4, @@ -253,11 +250,11 @@ def test_covariance_matrices_attention_mask( dataloader_kwargs=kwargs, ) covariance_factors = analyzer.load_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, ) analyzer.fit_covariance_matrices( - factors_name=factors_name + "_no_pad", + factors_name=custom_factors_name("no_pad"), dataset=no_padded_train_dataset, factor_args=factor_args, per_device_batch_size=1, @@ -265,7 +262,7 @@ def test_covariance_matrices_attention_mask( dataloader_kwargs=kwargs, ) no_padded_covariance_factors = analyzer.load_covariance_matrices( - factors_name=factors_name + "_no_pad", + factors_name=custom_factors_name("no_pad"), ) for name in COVARIANCE_FACTOR_NAMES: @@ -304,25 +301,25 @@ def test_covariance_matrices_automatic_batch_size( seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - factor_args = test_factor_arguments() - factors_name = f"pytest_{test_name}_{test_covariance_matrices_automatic_batch_size.__name__}" + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=8, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - covariance_factors = analyzer.load_covariance_matrices(factors_name=factors_name) + covariance_factors = analyzer.load_covariance_matrices(factors_name=DEFAULT_FACTORS_NAME) analyzer.fit_covariance_matrices( - factors_name=factors_name + "_auto", + factors_name=custom_factors_name("auto"), dataset=train_dataset, factor_args=factor_args, per_device_batch_size=None, @@ -330,7 +327,7 @@ def test_covariance_matrices_automatic_batch_size( dataloader_kwargs=kwargs, ) auto_covariance_factors = analyzer.load_covariance_matrices( - factors_name=factors_name + "_auto", + factors_name=custom_factors_name("auto"), ) for name in COVARIANCE_FACTOR_NAMES: @@ -343,12 +340,16 @@ def test_covariance_matrices_automatic_batch_size( @pytest.mark.parametrize("test_name", ["mlp"]) -@pytest.mark.parametrize("data_partition_size", [1, 4]) +@pytest.mark.parametrize("max_examples", [4, 26]) +@pytest.mark.parametrize("data_partitions", [1, 4]) +@pytest.mark.parametrize("module_partitions", [1, 2]) @pytest.mark.parametrize("train_size", [80]) @pytest.mark.parametrize("seed", [5]) 
def test_covariance_matrices_max_examples( test_name: str, - data_partition_size: int, + max_examples: int, + data_partitions: int, + module_partitions: int, train_size: int, seed: int, ) -> None: @@ -364,142 +365,147 @@ def test_covariance_matrices_max_examples( task=task, ) - MAX_EXAMPLES = 26 - factor_args = test_factor_arguments() - factor_args.covariance_max_examples = MAX_EXAMPLES - factor_args.covariance_data_partition_size = data_partition_size + factor_args = pytest_factor_arguments() + factor_args.covariance_max_examples = max_examples + factor_args.covariance_data_partitions = data_partitions + factor_args.covariance_module_partitions = module_partitions - factors_name = f"pytest_{test_name}_{test_covariance_matrices_max_examples.__name__}" analyzer.fit_covariance_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=32, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - covariance_factors = analyzer.load_covariance_matrices(factors_name=factors_name) + covariance_factors = analyzer.load_covariance_matrices(factors_name=DEFAULT_FACTORS_NAME) for num_examples in covariance_factors[NUM_ACTIVATION_COVARIANCE_PROCESSED].values(): - assert num_examples == MAX_EXAMPLES + assert num_examples == max_examples for num_examples in covariance_factors[NUM_GRADIENT_COVARIANCE_PROCESSED].values(): - assert num_examples == MAX_EXAMPLES + assert num_examples == max_examples -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "conv", - ], -) -@pytest.mark.parametrize("train_size", [101]) -@pytest.mark.parametrize("seed", [8]) -def test_covariance_matrices_amp( +@pytest.mark.parametrize("test_name", ["mlp", "gpt"]) +@pytest.mark.parametrize("train_size", [100]) +@pytest.mark.parametrize("seed", [6]) +def test_covariance_matrices_gradient_checkpoint( test_name: str, train_size: int, seed: int, ) -> None: - # Covariance matrices should be similar when AMP is enabled. + # Covariance matrices should be the same even when gradient checkpointing is used. 
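The gradient-checkpoint comparison introduced here rests on the fact that activation checkpointing trades memory for recomputation without changing gradients, so the fitted factors should match exactly. A rough self-contained illustration of that assumption, using a hypothetical two-layer model rather than the repository's test fixtures:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer1, layer2 = nn.Linear(4, 4), nn.Linear(4, 1)
x = torch.randn(8, 4)

# Plain forward and backward.
loss = layer2(torch.relu(layer1(x))).sum()
grads = torch.autograd.grad(loss, list(layer1.parameters()))

# Same computation, but the first block is recomputed during backward.
loss_cp = layer2(checkpoint(lambda t: torch.relu(layer1(t)), x, use_reentrant=False)).sum()
grads_cp = torch.autograd.grad(loss_cp, list(layer1.parameters()))

for g, g_cp in zip(grads, grads_cp):
    assert torch.allclose(g, g_cp)  # Checkpointing leaves gradients unchanged.
```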
model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) + + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, per_device_batch_size=8, overwrite_output_dir=True, - factor_args=factor_args, dataloader_kwargs=kwargs, + factor_args=factor_args, ) covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}" + factors_name=DEFAULT_FACTORS_NAME, ) - factor_args.amp_dtype = torch.float16 + model, _, _, _, task = prepare_test( + test_name=test_name + "_checkpoint", + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) analyzer.fit_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}_amp", + factors_name=custom_factors_name("cp"), dataset=train_dataset, - per_device_batch_size=8, + per_device_batch_size=4, + dataloader_kwargs=kwargs, overwrite_output_dir=True, factor_args=factor_args, - dataloader_kwargs=kwargs, ) - amp_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}_amp" + checkpoint_covariance_factors = analyzer.load_covariance_matrices( + factors_name=custom_factors_name("cp"), ) - for name in COVARIANCE_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - covariance_factors[name], - amp_covariance_factors[name], - atol=1e-01, - rtol=1e-02, - ) + assert check_tensor_dict_equivalence( + covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], + checkpoint_covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], + atol=ATOL, + rtol=RTOL, + ) @pytest.mark.parametrize("train_size", [100]) -@pytest.mark.parametrize("seed", [12]) -def test_covariance_matrices_gradient_checkpoint( +@pytest.mark.parametrize("seed", [7, 8]) +def test_covariance_matrices_inplace( train_size: int, seed: int, ) -> None: - # Covariance matrices should be the same even when gradient checkpointing is used. + # Covariance matrices should be identical with and without in-place ReLU.
model, train_dataset, _, data_collator, task = prepare_test( - test_name="mlp", + test_name="conv", train_size=train_size, seed=seed, ) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( - factors_name=f"pytest_{test_covariance_matrices_gradient_checkpoint.__name__}", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, per_device_batch_size=8, overwrite_output_dir=True, factor_args=factor_args, ) covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_covariance_matrices_gradient_checkpoint.__name__}", + factors_name=DEFAULT_FACTORS_NAME, ) model, _, _, _, task = prepare_test( - test_name="mlp_checkpoint", + test_name="conv_inplace", train_size=train_size, seed=seed, ) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) analyzer.fit_covariance_matrices( - factors_name=f"pytest_{test_covariance_matrices_gradient_checkpoint.__name__}_cp", + factors_name=custom_factors_name("inplace"), dataset=train_dataset, per_device_batch_size=4, overwrite_output_dir=True, factor_args=factor_args, ) - checkpoint_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_covariance_matrices_gradient_checkpoint.__name__}_cp", + inplace_covariance_factors = analyzer.load_covariance_matrices( + factors_name=custom_factors_name("inplace"), ) assert check_tensor_dict_equivalence( covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], - checkpoint_covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], + inplace_covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], atol=ATOL, rtol=RTOL, ) diff --git a/tests/factors/test_eigendecompositions.py b/tests/factors/test_eigendecompositions.py new file mode 100644 index 0000000..3198754 --- /dev/null +++ b/tests/factors/test_eigendecompositions.py @@ -0,0 +1,66 @@ +# pylint: skip-file + +import pytest +import torch + +from kronfluence.arguments import FactorArguments +from kronfluence.utils.constants import ( + ACTIVATION_EIGENVECTORS_NAME, + EIGENDECOMPOSITION_FACTOR_NAMES, + GRADIENT_EIGENVECTORS_NAME, +) +from kronfluence.utils.dataset import DataLoaderKwargs +from tests.utils import DEFAULT_FACTORS_NAME, prepare_model_and_analyzer, prepare_test + + +@pytest.mark.parametrize( + "test_name", + [ + "mlp", + "conv", + "bert", + ], +) +@pytest.mark.parametrize("eigendecomposition_dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("train_size", [1, 30]) +@pytest.mark.parametrize("seed", [0]) +def test_perform_eigendecomposition( + test_name: str, + eigendecomposition_dtype: torch.dtype, + train_size: int, + seed: int, +) -> None: + # Makes sure that the Eigendecomposition computations are working properly. 
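The `test_perform_eigendecomposition` check added in the new file ultimately relies on a simple numerical invariant: eigendecomposing a symmetric covariance matrix and multiplying the factors back together reproduces the original matrix. A sketch of that invariant in plain `torch.linalg`, independent of Kronfluence's internals:

```python
import torch

torch.manual_seed(0)
a = torch.randn(64, 16, dtype=torch.float64)
cov = a.T @ a  # Symmetric PSD matrix, standing in for an activation covariance.

eigenvalues, eigenvectors = torch.linalg.eigh(cov)
reconstructed = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T

assert torch.allclose(cov, reconstructed, atol=1e-8)
```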
+ model, train_dataset, _, data_collator, task = prepare_test( + test_name=test_name, + train_size=train_size, + seed=seed, + ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + factor_args = FactorArguments( + eigendecomposition_dtype=eigendecomposition_dtype, + ) + analyzer.fit_covariance_matrices( + factors_name=DEFAULT_FACTORS_NAME, + factor_args=factor_args, + dataset=train_dataset, + per_device_batch_size=None, + overwrite_output_dir=True, + dataloader_kwargs=kwargs, + ) + analyzer.perform_eigendecomposition( + factors_name=DEFAULT_FACTORS_NAME, + factor_args=factor_args, + overwrite_output_dir=True, + ) + eigen_factors = analyzer.load_eigendecomposition(factors_name=DEFAULT_FACTORS_NAME) + assert set(eigen_factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) + assert len(eigen_factors[ACTIVATION_EIGENVECTORS_NAME]) > 0 + assert len(eigen_factors[GRADIENT_EIGENVECTORS_NAME]) > 0 + for module_name in eigen_factors[ACTIVATION_EIGENVECTORS_NAME]: + assert eigen_factors[ACTIVATION_EIGENVECTORS_NAME][module_name] is not None + assert eigen_factors[GRADIENT_EIGENVECTORS_NAME][module_name] is not None diff --git a/tests/factors/test_eigens.py b/tests/factors/test_lambdas.py similarity index 62% rename from tests/factors/test_eigens.py rename to tests/factors/test_lambdas.py index cb8facb..d2f15e8 100644 --- a/tests/factors/test_eigens.py +++ b/tests/factors/test_lambdas.py @@ -4,11 +4,8 @@ import torch from kronfluence.arguments import FactorArguments -from kronfluence.utils.common.factor_arguments import test_factor_arguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments from kronfluence.utils.constants import ( - ACTIVATION_EIGENVECTORS_NAME, - EIGENDECOMPOSITION_FACTOR_NAMES, - GRADIENT_EIGENVECTORS_NAME, LAMBDA_FACTOR_NAMES, LAMBDA_MATRIX_NAME, NUM_LAMBDA_PROCESSED, @@ -16,8 +13,10 @@ from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ( ATOL, + DEFAULT_FACTORS_NAME, RTOL, check_tensor_dict_equivalence, + custom_factors_name, prepare_model_and_analyzer, prepare_test, ) @@ -28,68 +27,11 @@ [ "mlp", "repeated_mlp", - "mlp_checkpoint", "conv", - "conv_bn", - "bert", - "gpt", - ], -) -@pytest.mark.parametrize("eigendecomposition_dtype", [torch.float32, torch.float64]) -@pytest.mark.parametrize("train_size", [16]) -@pytest.mark.parametrize("seed", [0]) -def test_perform_eigendecomposition( - test_name: str, - eigendecomposition_dtype: torch.dtype, - train_size: int, - seed: int, -) -> None: - # Makes sure that the Eigendecomposition computations are working properly. 
- model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - factor_args = FactorArguments( - eigendecomposition_dtype=eigendecomposition_dtype, - ) - factors_name = f"pytest_{test_name}_{test_perform_eigendecomposition.__name__}" - analyzer.fit_covariance_matrices( - factors_name=factors_name, - factor_args=factor_args, - dataset=train_dataset, - per_device_batch_size=4, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - analyzer.perform_eigendecomposition( - factors_name=factors_name, - factor_args=factor_args, - overwrite_output_dir=True, - ) - eigen_factors = analyzer.load_eigendecomposition(factors_name=factors_name) - assert set(eigen_factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) - assert len(eigen_factors[ACTIVATION_EIGENVECTORS_NAME]) > 0 - for module_name in eigen_factors[ACTIVATION_EIGENVECTORS_NAME]: - assert eigen_factors[ACTIVATION_EIGENVECTORS_NAME][module_name] is not None - assert eigen_factors[GRADIENT_EIGENVECTORS_NAME][module_name] is not None - - -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "repeated_mlp", - "mlp_checkpoint", - "conv", - "conv_bn", "bert", + "roberta", "gpt", + "gpt_checkpoint", ], ) @pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.bfloat16]) @@ -120,11 +62,10 @@ def test_fit_lambda_matrices( per_sample_gradient_dtype=per_sample_gradient_dtype, ) if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True - factors_name = f"pytest_{test_name}_{test_fit_lambda_matrices.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, per_device_batch_size=train_size // 4, factor_args=factor_args, @@ -132,7 +73,7 @@ def test_fit_lambda_matrices( overwrite_output_dir=True, ) - lambda_factors = analyzer.load_lambda_matrices(factors_name=factors_name) + lambda_factors = analyzer.load_lambda_matrices(factors_name=DEFAULT_FACTORS_NAME) assert set(lambda_factors.keys()) == set(LAMBDA_FACTOR_NAMES) assert len(lambda_factors[LAMBDA_MATRIX_NAME]) > 0 for module_name in lambda_factors[LAMBDA_MATRIX_NAME]: @@ -144,7 +85,7 @@ def test_fit_lambda_matrices( [ "mlp", "conv", - "gpt", + "roberta", ], ) @pytest.mark.parametrize("strategy", ["diagonal", "ekfac"]) @@ -163,14 +104,15 @@ def test_lambda_matrices_batch_size_equivalence( seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - factor_args = test_factor_arguments(strategy=strategy) + factor_args = pytest_factor_arguments(strategy=strategy) analyzer.fit_all_factors( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_batch_size_equivalence.__name__}_{strategy}_bs1", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, per_device_batch_size=1, factor_args=factor_args, @@ -178,11 +120,11 @@ def test_lambda_matrices_batch_size_equivalence( overwrite_output_dir=True, ) bs1_lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_batch_size_equivalence.__name__}_{strategy}_bs1", + factors_name=DEFAULT_FACTORS_NAME, ) analyzer.fit_all_factors( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_batch_size_equivalence.__name__}_{strategy}_bs8", + 
factors_name=custom_factors_name("bs8"), dataset=train_dataset, per_device_batch_size=8, factor_args=factor_args, @@ -190,30 +132,39 @@ def test_lambda_matrices_batch_size_equivalence( overwrite_output_dir=True, ) bs8_lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_batch_size_equivalence.__name__}_{strategy}_bs8", + factors_name=custom_factors_name("bs8"), ) for name in LAMBDA_FACTOR_NAMES: assert check_tensor_dict_equivalence(bs1_lambda_factors[name], bs8_lambda_factors[name], atol=ATOL, rtol=RTOL) + analyzer.fit_all_factors( + factors_name=custom_factors_name("auto"), + dataset=train_dataset, + per_device_batch_size=None, + factor_args=factor_args, + dataloader_kwargs=kwargs, + overwrite_output_dir=True, + ) + auto_lambda_factors = analyzer.load_lambda_matrices( + factors_name=custom_factors_name("auto"), + ) + + for name in LAMBDA_FACTOR_NAMES: + assert check_tensor_dict_equivalence(bs1_lambda_factors[name], auto_lambda_factors[name], atol=ATOL, rtol=RTOL) -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "conv", - ], -) + +@pytest.mark.parametrize("test_name", ["mlp", "conv_bn"]) @pytest.mark.parametrize("strategy", ["diagonal", "ekfac"]) -@pytest.mark.parametrize("data_partition_size", [4]) -@pytest.mark.parametrize("module_partition_size", [3]) +@pytest.mark.parametrize("data_partitions", [2, 4]) +@pytest.mark.parametrize("module_partitions", [2, 3]) @pytest.mark.parametrize("train_size", [81]) @pytest.mark.parametrize("seed", [2]) def test_lambda_matrices_partition_equivalence( test_name: str, strategy: str, - data_partition_size: int, - module_partition_size: int, + data_partitions: int, + module_partitions: int, train_size: int, seed: int, ) -> None: @@ -229,10 +180,9 @@ def test_lambda_matrices_partition_equivalence( task=task, ) - factor_args = test_factor_arguments(strategy=strategy) - factors_name = f"pytest_{test_name}_{strategy}_{test_lambda_matrices_partition_equivalence.__name__}" + factor_args = pytest_factor_arguments(strategy=strategy) analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=8, @@ -240,13 +190,13 @@ def test_lambda_matrices_partition_equivalence( dataloader_kwargs=kwargs, ) lambda_factors = analyzer.load_lambda_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, ) - factor_args.lambda_data_partition_size = data_partition_size - factor_args.lambda_module_partition_size = module_partition_size + factor_args.lambda_data_partitions = data_partitions + factor_args.lambda_module_partitions = module_partitions analyzer.fit_all_factors( - factors_name=f"pytest_{test_name}_{strategy}_{data_partition_size}_{module_partition_size}", + factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"), dataset=train_dataset, factor_args=factor_args, per_device_batch_size=6, @@ -254,7 +204,7 @@ def test_lambda_matrices_partition_equivalence( dataloader_kwargs=kwargs, ) partitioned_lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_name}_{strategy}_{data_partition_size}_{module_partition_size}", + factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"), ) for name in LAMBDA_FACTOR_NAMES: assert check_tensor_dict_equivalence( @@ -266,14 +216,14 @@ def test_lambda_matrices_partition_equivalence( "test_name", [ "mlp", - "conv", + "conv_bn", "bert", "gpt", ], ) -@pytest.mark.parametrize("train_size", [63]) 
+@pytest.mark.parametrize("train_size", [63, 121]) @pytest.mark.parametrize("seed", [3]) -def test_lambda_matrices_iterative_aggregate( +def test_lambda_matrices_iterative_lambda_aggregation( test_name: str, train_size: int, seed: int, @@ -291,11 +241,10 @@ def test_lambda_matrices_iterative_aggregate( task=task, ) - factors_name = f"pytest_{test_name}_{test_lambda_matrices_iterative_aggregate.__name__}" - factor_args = test_factor_arguments() - factor_args.lambda_iterative_aggregate = False + factor_args = pytest_factor_arguments() + factor_args.use_iterative_lambda_aggregation = False analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=8, @@ -303,20 +252,20 @@ def test_lambda_matrices_iterative_aggregate( dataloader_kwargs=kwargs, ) lambda_factors = analyzer.load_lambda_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, ) - factor_args.lambda_iterative_aggregate = True + factor_args.use_iterative_lambda_aggregation = True analyzer.fit_all_factors( - factors_name=factors_name + "_iterative", + factors_name=custom_factors_name("iterative"), dataset=train_dataset, factor_args=factor_args, - per_device_batch_size=4, + per_device_batch_size=16, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) iterative_lambda_factors = analyzer.load_lambda_matrices( - factors_name=factors_name + "_iterative", + factors_name=custom_factors_name("iterative"), ) for name in LAMBDA_FACTOR_NAMES: @@ -325,14 +274,18 @@ def test_lambda_matrices_iterative_aggregate( @pytest.mark.parametrize( "test_name", - ["mlp", "conv"], + ["conv_bn", "gpt"], ) -@pytest.mark.parametrize("data_partition_size", [1, 4]) +@pytest.mark.parametrize("max_examples", [4, 31]) +@pytest.mark.parametrize("data_partitions", [1, 3]) +@pytest.mark.parametrize("module_partitions", [1, 2]) @pytest.mark.parametrize("train_size", [82]) @pytest.mark.parametrize("seed", [4]) def test_lambda_matrices_max_examples( test_name: str, - data_partition_size: int, + max_examples: int, + data_partitions: int, + module_partitions: int, train_size: int, seed: int, ) -> None: @@ -348,13 +301,13 @@ def test_lambda_matrices_max_examples( task=task, ) - MAX_EXAMPLES = 33 factor_args = FactorArguments( - use_empirical_fisher=True, lambda_max_examples=MAX_EXAMPLES, lambda_data_partition_size=data_partition_size + lambda_max_examples=max_examples, + lambda_data_partitions=data_partitions, + lambda_module_partitions=module_partitions, ) - factors_name = f"pytest_{test_name}_{test_lambda_matrices_max_examples.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, factor_args=factor_args, per_device_batch_size=8, @@ -362,176 +315,188 @@ def test_lambda_matrices_max_examples( dataloader_kwargs=kwargs, ) lambda_factors = analyzer.load_lambda_matrices( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, ) for num_examples in lambda_factors[NUM_LAMBDA_PROCESSED].values(): - assert num_examples == MAX_EXAMPLES + assert num_examples == max_examples @pytest.mark.parametrize( "test_name", [ "mlp", - "conv", + "gpt", ], ) -@pytest.mark.parametrize("train_size", [100]) -@pytest.mark.parametrize("seed", [8]) -def test_lambda_matrices_amp( +@pytest.mark.parametrize("train_size", [105]) +@pytest.mark.parametrize("seed", [6]) +def test_lambda_matrices_gradient_checkpoint( test_name: str, train_size: int, seed: int, ) -> None: - # Lambda matrices should be similar 
when AMP is enabled. + # Lambda matrices should be the same even when gradient checkpointing is used. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, seed=seed, ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_all_factors( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=8, + per_device_batch_size=5, overwrite_output_dir=True, + factor_args=factor_args, dataloader_kwargs=kwargs, ) lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}" + factors_name=DEFAULT_FACTORS_NAME, ) - factor_args.amp_dtype = torch.float16 + model, _, _, _, task = prepare_test( + test_name=test_name + "_checkpoint", + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) analyzer.fit_all_factors( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}_amp", + factors_name=custom_factors_name("cp"), dataset=train_dataset, - per_device_batch_size=8, + per_device_batch_size=6, overwrite_output_dir=True, factor_args=factor_args, dataloader_kwargs=kwargs, ) - amp_lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}_amp", + checkpoint_lambda_factors = analyzer.load_lambda_matrices( + factors_name=custom_factors_name("cp"), ) for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence(lambda_factors[name], amp_lambda_factors[name], atol=1e-01, rtol=1e-02) + assert check_tensor_dict_equivalence( + lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL + ) +@pytest.mark.parametrize( + "test_name", + ["mlp", "conv", "gpt"], +) @pytest.mark.parametrize("train_size", [105]) -@pytest.mark.parametrize("seed", [12]) -def test_lambda_matrices_gradient_checkpoint( +@pytest.mark.parametrize("seed", [7]) +def test_lambda_matrices_shared_parameters( + test_name: str, train_size: int, seed: int, ) -> None: - # Lambda matrices should be the same even when gradient checkpointing is used. + # When there are no shared parameters, runs with and without `has_shared_parameters` should + # produce the same results.
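For context on the `has_shared_parameters` flag exercised in the test that follows: "shared parameters" means a module whose forward is invoked more than once per pass (as in the `repeated_mlp` fixture), so its per-sample gradient is a sum over all uses. A minimal hypothetical example of such a model, not taken from the test suite:

```python
import torch
from torch import nn


class TiedMLP(nn.Module):
    """The same linear layer is applied twice per forward pass, so its
    parameters receive gradient contributions from both uses."""

    def __init__(self) -> None:
        super().__init__()
        self.shared = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.shared(torch.relu(self.shared(x)))
```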
model, train_dataset, _, data_collator, task = prepare_test( - test_name="mlp", + test_name=test_name, train_size=train_size, seed=seed, ) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_all_factors( - factors_name=f"pytest_{test_lambda_matrices_gradient_checkpoint.__name__}", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, per_device_batch_size=5, overwrite_output_dir=True, factor_args=factor_args, + dataloader_kwargs=kwargs, ) lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_lambda_matrices_gradient_checkpoint.__name__}", + factors_name=DEFAULT_FACTORS_NAME, ) - model, _, _, _, task = prepare_test( - test_name="mlp_checkpoint", - train_size=train_size, - seed=seed, - ) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) + factor_args.has_shared_parameters = True analyzer.fit_all_factors( - factors_name=f"pytest_{test_lambda_matrices_gradient_checkpoint.__name__}_cp", + factors_name=custom_factors_name("shared"), dataset=train_dataset, per_device_batch_size=6, overwrite_output_dir=True, factor_args=factor_args, + dataloader_kwargs=kwargs, ) - checkpoint_lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_lambda_matrices_gradient_checkpoint.__name__}_cp", + shared_lambda_factors = analyzer.load_lambda_matrices( + factors_name=custom_factors_name("shared"), ) for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL - ) + assert check_tensor_dict_equivalence(lambda_factors[name], shared_lambda_factors[name], atol=ATOL, rtol=RTOL) -@pytest.mark.parametrize("train_size", [105]) -@pytest.mark.parametrize("seed", [12]) -def test_lambda_matrices_shared_parameters( +@pytest.mark.parametrize("train_size", [121]) +@pytest.mark.parametrize("seed", [8]) +def test_lambda_matrices_inplace( train_size: int, seed: int, ) -> None: - # When there are no shared parameters, results with and without `shared_parameters_exist` should - produce the same results. + # Lambda matrices should be identical with and without in-place ReLU.
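Both in-place tests encode the expectation that `nn.ReLU(inplace=True)` is numerically identical to the out-of-place version. A quick standalone check of that assumption (hypothetical toy modules, not the `conv_inplace` fixture):

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(16, 8)

out_of_place = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
in_place = nn.Sequential(nn.Linear(8, 8), nn.ReLU(inplace=True))
in_place.load_state_dict(out_of_place.state_dict())  # Identical weights.

with torch.no_grad():
    assert torch.equal(out_of_place(x), in_place(x))
```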
model, train_dataset, _, data_collator, task = prepare_test( - test_name="mlp", + test_name="conv", train_size=train_size, seed=seed, ) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_all_factors( - factors_name=f"pytest_{test_lambda_matrices_shared_parameters.__name__}", + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, per_device_batch_size=5, overwrite_output_dir=True, factor_args=factor_args, + dataloader_kwargs=kwargs, ) lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_lambda_matrices_shared_parameters.__name__}", + factors_name=DEFAULT_FACTORS_NAME, ) - model, train_dataset, _, _, task = prepare_test( - test_name="mlp", + model, _, _, _, task = prepare_test( + test_name="conv_inplace", train_size=train_size, seed=seed, ) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) - - factor_args.shared_parameters_exist = True analyzer.fit_all_factors( - factors_name=f"pytest_{test_lambda_matrices_shared_parameters.__name__}_shared", + factors_name=custom_factors_name("inplace"), dataset=train_dataset, per_device_batch_size=6, overwrite_output_dir=True, factor_args=factor_args, + dataloader_kwargs=kwargs, ) - checkpoint_lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_lambda_matrices_shared_parameters.__name__}_shared", + inplace_lambda_factors = analyzer.load_lambda_matrices( + factors_name=custom_factors_name("inplace"), ) for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL - ) + assert check_tensor_dict_equivalence(lambda_factors[name], inplace_lambda_factors[name], atol=ATOL, rtol=RTOL) diff --git a/tests/gpu_tests/README.md b/tests/gpu_tests/README.md index dee6108..e83f15d 100644 --- a/tests/gpu_tests/README.md +++ b/tests/gpu_tests/README.md @@ -17,18 +17,18 @@ python cpu_test.py ### DDP Tests -To test if running with Distributed Data Parallel (DDP) with 3 GPUs obtains the same result, run: +To test if running with Distributed Data Parallel (DDP) on 4 GPUs obtains the same result, run: ```bash -torchrun --nnodes=1 --nproc_per_node=3 ddp_test.py +torchrun --nnodes=1 --nproc_per_node=4 ddp_test.py ``` ### FSDP Tests -To test if running with Fully Sharded Data Parallel (FSDP) with 3 GPUs obtains the same result, run: +To test if running with Fully Sharded Data Parallel (FSDP) on 4 GPUs obtains the same result, run: ```bash -torchrun --nnodes=1 --nproc_per_node=3 fsdp_test.py +torchrun --nnodes=1 --nproc_per_node=4 fsdp_test.py ``` ### torch.compile Tests @@ -49,7 +49,7 @@ python amp_test.py ### CPU Offload Test -To test if `cached_activation_cpu_offload` option is properly implemented, run: +To test if the `offload_activations_to_cpu` option is properly implemented, run: ```bash pytest test_offload_cpu.py diff --git a/tests/gpu_tests/amp_test.py b/tests/gpu_tests/amp_test.py index 8f2ad17..0b7a58e 100644 --- a/tests/gpu_tests/amp_test.py +++ b/tests/gpu_tests/amp_test.py @@ -7,7 +7,8 @@ from torch.utils import data from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from 
kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ( ALL_MODULE_NAME, COVARIANCE_FACTOR_NAMES, @@ -45,13 +46,8 @@ def setUpClass(cls) -> None: def test_covariance_matrices(self) -> None: covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - amp_dtype=torch.float16, - ) + factor_args = pytest_factor_arguments() + factor_args.amp_dtype = torch.bfloat16 self.analyzer.fit_covariance_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -73,15 +69,10 @@ def test_covariance_matrices(self) -> None: rtol=1e-1, ) - def test_lambda_matrices(self): + def test_lambda_matrices(self) -> None: lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - amp_dtype=torch.float16, - ) + factor_args = pytest_factor_arguments() + factor_args.amp_dtype = torch.bfloat16 self.analyzer.fit_lambda_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -107,12 +98,8 @@ def test_lambda_matrices(self): def test_pairwise_scores(self) -> None: pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - amp_dtype=torch.float16, - ) + score_args = pytest_score_arguments() + score_args.amp_dtype = torch.bfloat16 self.analyzer.compute_pairwise_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -134,17 +121,13 @@ def test_pairwise_scores(self) -> None: assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-5, - rtol=1e-3, + atol=1e-3, + rtol=1e-1, ) def test_self_scores(self) -> None: - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - amp_dtype=torch.float16, - ) + score_args = pytest_score_arguments() + score_args.amp_dtype = torch.bfloat16 self.analyzer.compute_self_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -155,7 +138,6 @@ def test_self_scores(self) -> None: overwrite_output_dir=True, ) new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME) - self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") @@ -164,8 +146,8 @@ def test_self_scores(self) -> None: assert check_tensor_dict_equivalence( self_scores, new_self_scores, - atol=1e-5, - rtol=1e-3, + atol=1e-3, + rtol=1e-1, ) diff --git a/tests/gpu_tests/compile_test.py b/tests/gpu_tests/compile_test.py index a997ddf..10b3a6e 100644 --- a/tests/gpu_tests/compile_test.py +++ b/tests/gpu_tests/compile_test.py @@ -7,7 +7,8 @@ from torch.utils import data from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ( ALL_MODULE_NAME, 
COVARIANCE_FACTOR_NAMES, @@ -15,7 +16,7 @@ ) from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence logging.basicConfig(level=logging.DEBUG) OLD_FACTOR_NAME = "single_gpu" @@ -41,6 +42,9 @@ def setUpClass(cls) -> None: cls.model = cls.model.cuda() cls.model = torch.compile(cls.model) + print(cls.model) + print(list(cls.model.named_modules())) + cls.analyzer = Analyzer( analysis_name="gpu_test", model=cls.model, @@ -49,12 +53,7 @@ def setUpClass(cls) -> None: def test_covariance_matrices(self) -> None: covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_covariance_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -69,15 +68,42 @@ def test_covariance_matrices(self) -> None: print(f"Name: {name, module_name}") print(f"Previous factor: {covariance_factors[name][module_name]}") print(f"New factor: {new_covariance_factors[name][module_name]}") + assert check_tensor_dict_equivalence( + covariance_factors[name], + new_covariance_factors[name], + atol=ATOL, + rtol=RTOL, + ) - def test_lambda_matrices(self): + def test_lambda_matrices(self) -> None: lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, + factor_args = pytest_factor_arguments() + self.analyzer.fit_lambda_matrices( + factors_name=NEW_FACTOR_NAME, + dataset=self.train_dataset, + factor_args=factor_args, + per_device_batch_size=512, + overwrite_output_dir=True, + load_from_factors_name=OLD_FACTOR_NAME, ) + new_lambda_factors = self.analyzer.load_lambda_matrices(factors_name=NEW_FACTOR_NAME) + + for name in LAMBDA_FACTOR_NAMES: + for module_name in lambda_factors[name]: + print(f"Name: {name, module_name}") + print(f"Previous factor: {lambda_factors[name][module_name]}") + print(f"New factor: {new_lambda_factors[name][module_name]}") + assert check_tensor_dict_equivalence( + lambda_factors[name], + new_lambda_factors[name], + atol=ATOL, + rtol=RTOL, + ) + + def test_lambda_shared_matrices(self) -> None: + lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) + factor_args = pytest_factor_arguments() + factor_args.has_shared_parameters = True self.analyzer.fit_lambda_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -96,18 +122,12 @@ def test_lambda_matrices(self): assert check_tensor_dict_equivalence( lambda_factors[name], new_lambda_factors[name], - atol=1e-3, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) def test_pairwise_scores(self) -> None: - pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_pairwise_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -121,6 +141,7 @@ def test_pairwise_scores(self) -> None: 
overwrite_output_dir=True, ) new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][10]}") print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") @@ -129,16 +150,12 @@ def test_pairwise_scores(self) -> None: assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) def test_self_scores(self) -> None: - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_self_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -149,8 +166,8 @@ def test_self_scores(self) -> None: overwrite_output_dir=True, ) new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME) - self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) + print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") print(f"New score: {new_self_scores[ALL_MODULE_NAME]}") @@ -158,8 +175,8 @@ def test_self_scores(self) -> None: assert check_tensor_dict_equivalence( self_scores, new_self_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) diff --git a/tests/gpu_tests/cpu_test.py b/tests/gpu_tests/cpu_test.py index 455327d..29c3f13 100644 --- a/tests/gpu_tests/cpu_test.py +++ b/tests/gpu_tests/cpu_test.py @@ -7,7 +7,8 @@ from torch.utils import data from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ( ALL_MODULE_NAME, COVARIANCE_FACTOR_NAMES, @@ -15,7 +16,7 @@ ) from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence logging.basicConfig(level=logging.DEBUG) OLD_FACTOR_NAME = "single_gpu" @@ -43,12 +44,7 @@ def setUpClass(cls) -> None: def test_covariance_matrices(self) -> None: covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_covariance_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -66,18 +62,13 @@ def test_covariance_matrices(self) -> None: assert check_tensor_dict_equivalence( covariance_factors[name], new_covariance_factors[name], - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) - def test_lambda_matrices(self): + def test_lambda_matrices(self) -> None: lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_lambda_matrices( factors_name=NEW_FACTOR_NAME, 
dataset=self.train_dataset, @@ -96,18 +87,14 @@ def test_lambda_matrices(self): assert check_tensor_dict_equivalence( lambda_factors[name], new_lambda_factors[name], - atol=1e-3, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) def test_pairwise_scores(self) -> None: pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_pairwise_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -129,16 +116,12 @@ def test_pairwise_scores(self) -> None: assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) def test_self_scores(self) -> None: - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_self_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -158,8 +141,34 @@ def test_self_scores(self) -> None: assert check_tensor_dict_equivalence( self_scores, new_self_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, + ) + + def test_self_scores_with_measurement(self) -> None: + score_args = pytest_score_arguments() + score_args.use_measurement_for_self_influence = True + self.analyzer.compute_self_scores( + scores_name=NEW_SCORE_NAME + "_measurement", + factors_name=OLD_FACTOR_NAME, + train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME + "_measurement") + + self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME + "_measurement") + print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") + print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_self_scores[ALL_MODULE_NAME]}") + print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}") + assert check_tensor_dict_equivalence( + self_scores, + new_self_scores, + atol=ATOL, + rtol=RTOL, ) diff --git a/tests/gpu_tests/ddp_test.py b/tests/gpu_tests/ddp_test.py index a05d569..7ad3e63 100644 --- a/tests/gpu_tests/ddp_test.py +++ b/tests/gpu_tests/ddp_test.py @@ -9,7 +9,8 @@ from torch.utils import data from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ( ALL_MODULE_NAME, COVARIANCE_FACTOR_NAMES, @@ -18,12 +19,12 @@ from kronfluence.utils.model import apply_ddp from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence LOCAL_RANK = int(os.environ["LOCAL_RANK"]) WORLD_RANK = int(os.environ["RANK"]) WORLD_SIZE = int(os.environ["WORLD_SIZE"]) -logging.basicConfig(level=logging.DEBUG) +logging.basicConfig(level=logging.INFO) OLD_FACTOR_NAME = "single_gpu" NEW_FACTOR_NAME = "ddp" OLD_SCORE_NAME = "single_gpu" @@ -60,12 +61,7 @@ def setUpClass(cls) -> None: def 
test_covariance_matrices(self) -> None: covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_covariance_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -85,18 +81,42 @@ def test_covariance_matrices(self) -> None: assert check_tensor_dict_equivalence( covariance_factors[name], new_covariance_factors[name], - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) - def test_lambda_matrices(self): + def test_lambda_matrices(self) -> None: lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, + factor_args = pytest_factor_arguments() + self.analyzer.fit_lambda_matrices( + factors_name=NEW_FACTOR_NAME, + dataset=self.train_dataset, + factor_args=factor_args, + per_device_batch_size=512, + overwrite_output_dir=True, + load_from_factors_name=OLD_FACTOR_NAME, ) + new_lambda_factors = self.analyzer.load_lambda_matrices(factors_name=NEW_FACTOR_NAME) + + for name in LAMBDA_FACTOR_NAMES: + if LOCAL_RANK == 0: + for module_name in lambda_factors[name]: + print(f"Name: {name, module_name}") + print(f"Previous factor: {lambda_factors[name][module_name]}") + print(f"New factor: {new_lambda_factors[name][module_name]}") + if LOCAL_RANK == 0: + assert check_tensor_dict_equivalence( + lambda_factors[name], + new_lambda_factors[name], + atol=ATOL, + rtol=RTOL, + ) + + def test_lambda_partition_matrices(self) -> None: + lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) + factor_args = pytest_factor_arguments() + factor_args.lambda_module_partitions = 2 + factor_args.lambda_data_partitions = 2 self.analyzer.fit_lambda_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -117,18 +137,14 @@ def test_lambda_matrices(self): assert check_tensor_dict_equivalence( lambda_factors[name], new_lambda_factors[name], - atol=1e-3, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) def test_pairwise_scores(self) -> None: pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_pairwise_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -155,18 +171,48 @@ def test_pairwise_scores(self) -> None: assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) - def test_self_scores(self) -> None: - self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) + def test_pairwise_partition_scores(self) -> None: + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, + score_args = pytest_score_arguments() + score_args.module_partitions = 2 + score_args.data_partitions = 2 + self.analyzer.compute_pairwise_scores( + scores_name=NEW_SCORE_NAME, + factors_name=OLD_FACTOR_NAME, + query_dataset=self.eval_dataset, + 
train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + query_indices=list(range(QUERY_INDICES)), + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) + + if LOCAL_RANK == 0: + print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") + print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") + print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}") + print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}") + print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + assert check_tensor_dict_equivalence( + pairwise_scores, + new_pairwise_scores, + atol=ATOL, + rtol=RTOL, + ) + + def test_self_scores(self) -> None: + score_args = pytest_score_arguments() self.analyzer.compute_self_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -177,6 +223,7 @@ def test_self_scores(self) -> None: overwrite_output_dir=True, ) new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME) + self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) if LOCAL_RANK == 0: print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") @@ -186,20 +233,38 @@ def test_self_scores(self) -> None: assert check_tensor_dict_equivalence( self_scores, new_self_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) - def test_lr_pairwise_scores(self) -> None: - pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb") - - score_args = ScoreArguments( - query_gradient_rank=32, - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - query_gradient_svd_dtype=torch.float64, + score_args.use_measurement_for_self_influence = True + self.analyzer.compute_self_scores( + scores_name=NEW_SCORE_NAME + "_measurement", + factors_name=OLD_FACTOR_NAME, + train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, ) + new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME + "_measurement") + self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME + "_measurement") + + if LOCAL_RANK == 0: + print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") + print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_self_scores[ALL_MODULE_NAME]}") + print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}") + assert check_tensor_dict_equivalence( + self_scores, + new_self_scores, + atol=ATOL, + rtol=RTOL, + ) + + def test_lr_pairwise_scores(self) -> None: + score_args = pytest_score_arguments() + score_args.query_gradient_low_rank = 32 self.analyzer.compute_pairwise_scores( scores_name="ddp_qb", factors_name=OLD_FACTOR_NAME, @@ -212,77 +277,85 @@ def test_lr_pairwise_scores(self) -> None: score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb") + + def test_per_module_pairwise_scores(self) -> None: + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) + score_args = pytest_score_arguments() + score_args.compute_per_module_scores = 
True + self.analyzer.compute_pairwise_scores( + scores_name=NEW_SCORE_NAME + "_per_module", + factors_name=OLD_FACTOR_NAME, + query_dataset=self.eval_dataset, + train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + query_indices=list(range(QUERY_INDICES)), + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME + "_per_module") if LOCAL_RANK == 0: - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}") - print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}") - print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + total_scores = None + for module_name in new_pairwise_scores: + if total_scores is None: + total_scores = new_pairwise_scores[module_name] + else: + total_scores.add_(new_pairwise_scores[module_name]) assert check_tensor_dict_equivalence( pairwise_scores, - new_pairwise_scores, - atol=1e-3, - rtol=1e-1, + {ALL_MODULE_NAME: total_scores}, + atol=ATOL, + rtol=RTOL, ) - def test_per_module_pairwise_scores(self) -> None: - pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb") - - score_args = ScoreArguments( - per_module_score=True, - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - query_gradient_svd_dtype=torch.float64, + def test_lr_accumulate_pairwise_scores(self) -> None: + score_args = pytest_score_arguments() + score_args.query_gradient_low_rank = 32 + score_args.query_gradient_accumulation_steps = 3 + self.analyzer.compute_pairwise_scores( + scores_name="ddp_qb_agg", + factors_name=OLD_FACTOR_NAME, + query_dataset=self.eval_dataset, + train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + query_indices=list(range(QUERY_INDICES)), + per_device_query_batch_size=2, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, ) + + def test_aggregate_scores(self) -> None: + score_args = pytest_score_arguments() + score_args.aggregate_train_gradients = True self.analyzer.compute_pairwise_scores( - scores_name=NEW_SCORE_NAME + "_per_module", + scores_name="ddp", factors_name=OLD_FACTOR_NAME, query_dataset=self.eval_dataset, train_dataset=self.train_dataset, train_indices=list(range(TRAIN_INDICES)), query_indices=list(range(QUERY_INDICES)), - per_device_query_batch_size=12, + per_device_query_batch_size=2, per_device_train_batch_size=512, score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb") + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp") + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_train_agg") if LOCAL_RANK == 0: - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}") - print(f"Previous 
shape: {pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}") - print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-3, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) - def test_lr_accumulate_pairwise_scores(self) -> None: - pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb") - - score_args = ScoreArguments( - query_gradient_rank=32, - num_query_gradient_accumulations=3, - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - query_gradient_svd_dtype=torch.float64, - ) + score_args.aggregate_query_gradients = True self.analyzer.compute_pairwise_scores( - scores_name="ddp_qb_agg", + scores_name="ddp", factors_name=OLD_FACTOR_NAME, query_dataset=self.eval_dataset, train_dataset=self.train_dataset, @@ -293,20 +366,15 @@ def test_lr_accumulate_pairwise_scores(self) -> None: score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb_agg") + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp") + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_all_agg") if LOCAL_RANK == 0: - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][44]}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][44]}") assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-1, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) @classmethod diff --git a/tests/gpu_tests/fsdp_test.py b/tests/gpu_tests/fsdp_test.py index 8c08975..43e627b 100644 --- a/tests/gpu_tests/fsdp_test.py +++ b/tests/gpu_tests/fsdp_test.py @@ -9,7 +9,8 @@ from torch.utils import data from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ( ALL_MODULE_NAME, COVARIANCE_FACTOR_NAMES, @@ -18,7 +19,7 @@ from kronfluence.utils.model import apply_fsdp from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence LOCAL_RANK = int(os.environ["LOCAL_RANK"]) WORLD_RANK = int(os.environ["RANK"]) @@ -62,12 +63,7 @@ def setUpClass(cls) -> None: def test_covariance_matrices(self) -> None: covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_covariance_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -87,18 +83,13 @@ def test_covariance_matrices(self) -> None: assert check_tensor_dict_equivalence( covariance_factors[name], new_covariance_factors[name], - 
atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) - def test_lambda_matrices(self): + def test_lambda_matrices(self) -> None: lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_lambda_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -119,18 +110,14 @@ def test_lambda_matrices(self): assert check_tensor_dict_equivalence( lambda_factors[name], new_lambda_factors[name], - atol=1e-3, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) def test_pairwise_scores(self) -> None: pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_pairwise_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -153,18 +140,14 @@ def test_pairwise_scores(self) -> None: assert check_tensor_dict_equivalence( pairwise_scores, new_pairwise_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) def test_self_scores(self) -> None: self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() self.analyzer.compute_self_scores( scores_name=NEW_SCORE_NAME, factors_name=OLD_FACTOR_NAME, @@ -184,8 +167,8 @@ def test_self_scores(self) -> None: assert check_tensor_dict_equivalence( self_scores, new_self_scores, - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) @classmethod diff --git a/tests/gpu_tests/pipeline.py b/tests/gpu_tests/pipeline.py index 117ec52..76f542d 100644 --- a/tests/gpu_tests/pipeline.py +++ b/tests/gpu_tests/pipeline.py @@ -25,12 +25,12 @@ def compute_train_loss( if not sample: return F.cross_entropy(logits, labels, reduction="sum") with torch.no_grad(): - probs = torch.nn.functional.softmax(logits, dim=-1) + probs = torch.nn.functional.softmax(logits.detach(), dim=-1) sampled_labels = torch.multinomial( probs, num_samples=1, ).flatten() - return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum") + return F.cross_entropy(logits, sampled_labels, reduction="sum") def compute_measurement( self, diff --git a/tests/gpu_tests/prepare_tests.py b/tests/gpu_tests/prepare_tests.py index 1392edf..a9a140b 100644 --- a/tests/gpu_tests/prepare_tests.py +++ b/tests/gpu_tests/prepare_tests.py @@ -6,7 +6,8 @@ from tqdm import tqdm from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset # Pick difficult cases where the dataset is not perfectly divisible by batch size. 
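Throughout this diff, the hand-rolled FactorArguments / ScoreArguments blocks in the GPU tests are replaced by shared pytest_factor_arguments and pytest_score_arguments helpers from kronfluence.utils.common. A minimal sketch of what these helpers presumably return, inferred from the argument blocks they replace in the hunks above (the field choices are assumptions; only the signatures are confirmed by how the tests below call them):

from typing import Optional

import torch

from kronfluence.arguments import FactorArguments, ScoreArguments


def pytest_factor_arguments(strategy: str = "ekfac") -> FactorArguments:
    # Assumed contents: empirical Fisher plus float64 everywhere, mirroring the
    # explicit FactorArguments blocks this diff deletes.
    return FactorArguments(
        strategy=strategy,
        use_empirical_fisher=True,
        activation_covariance_dtype=torch.float64,
        gradient_covariance_dtype=torch.float64,
        per_sample_gradient_dtype=torch.float64,
        lambda_dtype=torch.float64,
    )


def pytest_score_arguments(query_gradient_low_rank: Optional[int] = None) -> ScoreArguments:
    # Assumed contents: float64 score, gradient, and preconditioning dtypes,
    # mirroring the explicit ScoreArguments blocks this diff deletes.
    return ScoreArguments(
        query_gradient_low_rank=query_gradient_low_rank,
        score_dtype=torch.float64,
        per_sample_gradient_dtype=torch.float64,
        precondition_dtype=torch.float64,
    )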
@@ -92,12 +93,7 @@ def run_analysis() -> None: task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() analyzer.fit_all_factors( factors_name="single_gpu", dataset=train_dataset, @@ -106,11 +102,7 @@ def run_analysis() -> None: overwrite_output_dir=True, ) - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = pytest_score_arguments() analyzer.compute_pairwise_scores( scores_name="single_gpu", factors_name="single_gpu", @@ -130,13 +122,19 @@ def run_analysis() -> None: overwrite_output_dir=True, ) - score_args = ScoreArguments( - query_gradient_rank=32, - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - query_gradient_svd_dtype=torch.float64, + score_args = pytest_score_arguments() + score_args.use_measurement_for_self_influence = True + analyzer.compute_self_scores( + scores_name="single_gpu_measurement", + factors_name="single_gpu", + train_dataset=train_dataset, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, ) + + score_args = pytest_score_arguments() + score_args.query_gradient_low_rank = 32 analyzer.compute_pairwise_scores( scores_name="single_gpu_qb", factors_name="single_gpu", @@ -148,6 +146,46 @@ def run_analysis() -> None: overwrite_output_dir=True, ) + score_args = pytest_score_arguments() + score_args.aggregate_train_gradients = True + analyzer.compute_pairwise_scores( + scores_name="single_gpu_train_agg", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.aggregate_query_gradients = True + analyzer.compute_pairwise_scores( + scores_name="single_gpu_query_agg", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.aggregate_train_gradients = True + score_args.aggregate_query_gradients = True + analyzer.compute_pairwise_scores( + scores_name="single_gpu_all_agg", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + if __name__ == "__main__": train() diff --git a/tests/gpu_tests/test_offload_cpu.py b/tests/gpu_tests/test_offload_cpu.py index 0066270..fca8ef7 100644 --- a/tests/gpu_tests/test_offload_cpu.py +++ b/tests/gpu_tests/test_offload_cpu.py @@ -6,6 +6,8 @@ from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ALL_MODULE_NAME from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test @@ -23,13 +25,13 @@ "gpt", ], ) 
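The three single_gpu_*_agg baselines added above exercise the new aggregate_train_gradients and aggregate_query_gradients flags, and the DDP tests compare an aggregated multi-GPU run directly against these baselines. The comparison is sound because pairwise influence is bilinear in the query and train gradients: summing gradients before the inner product equals summing per-example scores after it. A self-contained toy check of that identity (illustrative tensors only, not part of the suite):

import torch

torch.manual_seed(0)
query_grads = torch.rand(4, 10, dtype=torch.float64)  # (num_queries, num_params)
train_grads = torch.rand(8, 10, dtype=torch.float64)  # (num_train, num_params)

# Per-example pairwise scores, then summed over the train dimension.
pairwise = query_grads @ train_grads.t()  # (num_queries, num_train)
summed = pairwise.sum(dim=1, keepdim=True)

# Aggregate the train gradients first, then take the inner product.
aggregated = query_grads @ train_grads.sum(dim=0, keepdim=True).t()  # (num_queries, 1)

assert torch.allclose(summed, aggregated)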
-@pytest.mark.parametrize("cached_activation_cpu_offload", [True, False]) +@pytest.mark.parametrize("offload_activations_to_cpu", [True, False]) @pytest.mark.parametrize("query_size", [16]) @pytest.mark.parametrize("train_size", [32]) @pytest.mark.parametrize("seed", [1]) def test_cpu_offloads( test_name: str, - cached_activation_cpu_offload: bool, + offload_activations_to_cpu: bool, query_size: int, train_size: int, seed: int, @@ -50,10 +52,10 @@ def test_cpu_offloads( disable_tqdm=True, ) factor_args = FactorArguments( - cached_activation_cpu_offload=cached_activation_cpu_offload, + offload_activations_to_cpu=offload_activations_to_cpu, ) if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True factors_name = f"pytest_{test_name}_{test_cpu_offloads.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -65,7 +67,7 @@ def test_cpu_offloads( ) score_args = ScoreArguments( - cached_activation_cpu_offload=cached_activation_cpu_offload, + offload_activations_to_cpu=offload_activations_to_cpu, ) scores_name = f"pytest_{test_name}_{test_cpu_offloads.__name__}_scores" analyzer.compute_pairwise_scores( @@ -122,15 +124,9 @@ def test_cpu_offloads_identical( disable_model_save=True, disable_tqdm=True, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - cached_activation_cpu_offload=False, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True factors_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -140,13 +136,7 @@ def test_cpu_offloads_identical( factor_args=factor_args, overwrite_output_dir=True, ) - score_args = ScoreArguments( - cached_activation_cpu_offload=False, - per_sample_gradient_dtype=torch.float64, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_module_score=per_module_score, - ) + score_args = pytest_score_arguments() scores_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_scores" analyzer.compute_pairwise_scores( scores_name=scores_name, @@ -162,7 +152,7 @@ def test_cpu_offloads_identical( pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) factors_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_cached" - factor_args.cached_activation_cpu_offload = True + factor_args.offload_activations_to_cpu = True analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -171,7 +161,7 @@ def test_cpu_offloads_identical( factor_args=factor_args, overwrite_output_dir=True, ) - score_args.cached_activation_cpu_offload = True + score_args.offload_activations_to_cpu = True scores_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_cached_scores" analyzer.compute_pairwise_scores( scores_name=scores_name, diff --git a/tests/modules/test_matmul.py b/tests/modules/test_matmul.py new file mode 100644 index 0000000..aceb943 --- /dev/null +++ b/tests/modules/test_matmul.py @@ -0,0 +1,237 @@ +import time + +import opt_einsum +import pytest +import torch +from accelerate.utils import set_seed +from opt_einsum import DynamicProgramming + + +def test_query_gradient_svd( + seed: int = 0, +) -> None: + input_dim = 2048 + output_dim = 1024 + batch_dim = 16 + set_seed(seed) + + gradient = torch.rand(size=(batch_dim, 
output_dim, input_dim), dtype=torch.float64) + + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + assert torch.allclose(gradient, U @ torch.diag_embed(S) @ V, atol=1e-5, rtol=1e-3) + + rank = 32 + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + assert torch.bmm(left, right).shape == gradient.shape + + rank = input_dim + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + assert torch.allclose(torch.bmm(left, right), gradient, atol=1e-5, rtol=1e-3) + + rank = 32 + lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank), dtype=torch.float64) + lr_gradient2 = torch.rand(size=(batch_dim, rank, input_dim), dtype=torch.float64) + gradient = torch.bmm(lr_gradient1, lr_gradient2) + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + assert torch.allclose(torch.bmm(left_mat, right_mat), gradient, atol=1e-5, rtol=1e-3) + + query_batch_dim = 32 + new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) + score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) + + lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + assert torch.allclose(score, lr_score) + + lr_score_reconst_matmul = torch.matmul( + torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() + ) + assert torch.allclose(score, lr_score_reconst_matmul) + + # These should be able to avoid explicit reconstruction. This should be used when input_dim > output_dim. + intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) + final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) + assert torch.allclose(score, final) + print("Option 1") + print(intermediate.numel()) + + # This should be used when output_dim > input_dim. 
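+ # Option 1 materializes a "qtko" intermediate with q * t * rank * output_dim entries,
+ # while the "qtik" intermediate below holds q * t * input_dim * rank entries, so the
+ # cheaper contraction is the one indexed by the smaller of output_dim and input_dim.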
+ intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) + final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) + assert torch.allclose(score, final2) + print("Option 2") + print(intermediate2.numel()) + + print("Reconstruction") + print((torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel()) + path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + print(path) + + +@pytest.mark.parametrize("input_dim", [256, 512]) +@pytest.mark.parametrize("output_dim", [512, 1024]) +@pytest.mark.parametrize("batch_dim", [8, 16]) +@pytest.mark.parametrize("qbatch_dim", [8, 16]) +@pytest.mark.parametrize("rank", [32]) +@pytest.mark.parametrize("seed", [0]) +def test_query_gradient_svd_reconst( + input_dim: int, + output_dim: int, + batch_dim: int, + qbatch_dim: int, + rank: int, + seed: int, +) -> None: + set_seed(seed) + + lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank + 50), dtype=torch.float64) + lr_gradient2 = torch.rand(size=(batch_dim, rank + 50, input_dim), dtype=torch.float64) + gradient = torch.bmm(lr_gradient1, lr_gradient2) + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + new_gradient = torch.rand(size=(qbatch_dim, output_dim, input_dim), dtype=torch.float64) + + lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + lr_score_reconst_matmul = torch.matmul( + torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() + ) + assert torch.allclose(lr_score, lr_score_reconst_matmul) + + # This should be used when input_dim > output_dim. + intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) + final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) + assert torch.allclose(lr_score, final) + print("Option 1") + print(intermediate.numel()) + + # This should be used when output_dim > input_dim. 
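+ # The size comparison at the bottom of this test pits both einsum intermediates against
+ # explicit reconstruction, which materializes batch_dim * output_dim * input_dim entries.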
+ intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) + final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) + assert torch.allclose(lr_score, final2) + print("Option 2") + print(intermediate2.numel()) + + print("Reconstruction") + reconst_numel = (torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel() + print(reconst_numel) + path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + print(path) + + if new_gradient.size(0) * right_mat.size(0) * rank * min((right_mat.size(2), left_mat.size(1))) > right_mat.size( + 0 + ) * right_mat.size(2) * left_mat.size(1): + assert intermediate2.numel() > reconst_numel and intermediate.numel() > reconst_numel + elif output_dim >= input_dim: + assert intermediate2.numel() <= reconst_numel + else: + assert intermediate.numel() <= reconst_numel + + +def test_compute_score_matmul( + seed: int = 0, +) -> None: + input_dim = 4096 + output_dim = 100 + token_dim = 1 + batch_dim = 1024 + query_batch_dim = 2 + set_seed(seed) + + input_activation = torch.rand(size=(batch_dim, token_dim, input_dim), dtype=torch.float64) + output_gradient = torch.rand(size=(batch_dim, token_dim, output_dim), dtype=torch.float64) + gradient = opt_einsum.contract("b...i,b...o->bio", output_gradient, input_activation) + new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) + + score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) + path = opt_einsum.contract_path("toi,qoi->tq", gradient, new_gradient) + print(path) + + unsqueeze_score = opt_einsum.contract("t...,q...->tq", gradient, new_gradient) + assert torch.allclose(score, unsqueeze_score) + + path = opt_einsum.contract_path( + "bti,bto,qio->qb", + output_gradient, + input_activation, + new_gradient, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + ) + print(path) + + +def test_precondition_gradient( + seed: int = 0, +) -> None: + input_dim = 128 + output_dim = 256 + batch_dim = 8 + lambda_scale = 1000 + damping = 1e-08 + + set_seed(seed) + A = torch.rand(size=(input_dim, input_dim), dtype=torch.float64) + B = torch.rand(size=(output_dim, output_dim), dtype=torch.float64) + Lambda = torch.rand(size=(output_dim, input_dim), dtype=torch.float64) + gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) + + start_time = time.time() + rotated_gradient = torch.einsum( + "ij,bjl,lk->bik", + ( + B.t(), + gradient, + A, + ), + ) + rotated_gradient.div_(Lambda + damping) + results = lambda_scale * torch.einsum( + "ij,bjl,lk->bik", + (B, rotated_gradient, A.t()), + ) + print(f"Took {time.time() - start_time} seconds.") + + start_time = time.time() + grads_rot = torch.matmul( + B.t(), + torch.matmul( + gradient, + A, + ), + ) + scaled_lambda = Lambda / lambda_scale + grads_rot.div_(scaled_lambda) + raw_results = torch.matmul( + B, + torch.matmul( + grads_rot, + A.t(), + ), + ) + print(f"Took {time.time() - start_time} seconds.") + + assert torch.allclose(raw_results, results, atol=1e-5, rtol=1e-3) diff --git a/tests/modules/test_modules.py b/tests/modules/test_modules.py index 3dc2f3e..7e60489 100644 --- a/tests/modules/test_modules.py +++ b/tests/modules/test_modules.py @@ -9,13 +9,12 @@ from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import set_mode, wrap_tracked_modules -from kronfluence.utils.save import verify_models_equivalence from tests.utils import prepare_test 
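The test_precondition_gradient test in the new test_matmul.py above asserts that two EK-FAC preconditioning routes agree: dividing the rotated gradient by (Lambda + damping) and then rescaling, versus folding the scale into Lambda and skipping damping. A compact, runnable restatement of that algebra (toy shapes; A, B, and Lambda are random stand-ins for the eigenvector and eigenvalue factors, and the equivalence holds only because damping=1e-8 is negligible next to the uniform(0, 1) Lambda entries):

import torch

torch.manual_seed(0)
output_dim, input_dim, scale, damping = 6, 4, 1000.0, 1e-8
A = torch.rand(input_dim, input_dim, dtype=torch.float64)    # input-side factor
B = torch.rand(output_dim, output_dim, dtype=torch.float64)  # output-side factor
Lambda = torch.rand(output_dim, input_dim, dtype=torch.float64)
gradient = torch.rand(output_dim, input_dim, dtype=torch.float64)

rotated = B.t() @ gradient @ A

# Route 1: divide by the damped Lambda, rotate back, then rescale.
route1 = scale * (B @ (rotated / (Lambda + damping)) @ A.t())

# Route 2: fold the scale into Lambda (dividing by Lambda / scale multiplies by scale)
# and drop the damping term entirely.
route2 = B @ (rotated / (Lambda / scale)) @ A.t()

# The two routes differ only by the damping term, which these tolerances absorb.
assert torch.allclose(route1, route2, atol=1e-5, rtol=1e-3)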
@pytest.mark.parametrize( "test_name", - ["mlp", "conv", "conv_bn", "bert", "gpt"], + ["mlp", "conv_bn", "gpt"], ) @pytest.mark.parametrize( "mode", @@ -24,6 +23,7 @@ ModuleMode.COVARIANCE, ModuleMode.LAMBDA, ModuleMode.PRECONDITION_GRADIENT, + ModuleMode.GRADIENT_AGGREGATION, ], ) @pytest.mark.parametrize("train_size", [32]) @@ -34,6 +34,7 @@ def test_tracked_modules_forward_equivalence( train_size: int, seed: int, ) -> None: + # The forward pass should produce the same results with and without wrapped modules. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -82,6 +83,7 @@ def test_tracked_modules_forward_equivalence( ModuleMode.COVARIANCE, ModuleMode.LAMBDA, ModuleMode.PRECONDITION_GRADIENT, + ModuleMode.GRADIENT_AGGREGATION, ], ) @pytest.mark.parametrize("train_size", [32]) @@ -92,6 +94,7 @@ def test_tracked_modules_backward_equivalence( train_size: int, seed: int, ) -> None: + # The backward pass should produce the same results with and without wrapped modules. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -125,30 +128,10 @@ def test_tracked_modules_backward_equivalence( wrapped_loss = task.compute_train_loss(batch, wrapped_model, sample=False) wrapped_loss.backward() for name, param in wrapped_model.named_parameters(): - wrapped_grads[name] = param.grad.detach() + if param.grad is not None: + wrapped_grads[name] = param.grad.detach() for name, grad in wrapped_grads.items(): original_name = name.replace(".original_module", "") if original_name in original_grads: assert torch.allclose(grad, original_grads[original_name]) - - -def test_verify_models_equivalence() -> None: - model1, _, _, _, _ = prepare_test( - test_name="mlp", - train_size=10, - seed=0, - ) - model2, _, _, _, _ = prepare_test( - test_name="mlp", - train_size=10, - seed=1, - ) - model3, _, _, _, _ = prepare_test( - test_name="conv", - train_size=10, - seed=1, - ) - assert verify_models_equivalence(model1.state_dict(), model1.state_dict()) - assert not verify_models_equivalence(model1.state_dict(), model2.state_dict()) - assert not verify_models_equivalence(model1.state_dict(), model3.state_dict()) diff --git a/tests/modules/test_per_sample_gradients.py b/tests/modules/test_per_sample_gradients.py index b18e6d9..84a9481 100644 --- a/tests/modules/test_per_sample_gradients.py +++ b/tests/modules/test_per_sample_gradients.py @@ -4,7 +4,6 @@ import time from typing import Any, Dict, List -import opt_einsum import pytest import torch from accelerate.utils import find_batch_size, set_seed @@ -15,7 +14,8 @@ from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode, TrackedModule from kronfluence.module.utils import ( - finalize_preconditioned_gradient, + finalize_iteration, + get_tracked_module_names, set_mode, update_factor_args, ) @@ -99,6 +99,7 @@ def for_loop_per_sample_gradient( "conv", "conv_bn", "bert", + "roberta", "gpt", ], ) @@ -118,7 +119,7 @@ def test_for_loop_per_sample_gradient_equivalence( ) original_model = copy.deepcopy(model) - batch_size = 4 + batch_size = 3 num_batches = train_size // batch_size train_loader = DataLoader( train_dataset, @@ -138,8 +139,9 @@ def test_for_loop_per_sample_gradient_equivalence( strategy="identity", ) if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True update_factor_args(model=model, factor_args=factor_args) + tracked_modules_names = 
get_tracked_module_names(model=model) per_sample_gradients = [] set_mode(model, ModuleMode.PRECONDITION_GRADIENT) @@ -155,7 +157,7 @@ def test_for_loop_per_sample_gradient_equivalence( loss.backward() if test_name == "repeated_mlp": - finalize_preconditioned_gradient(model=model) + finalize_iteration(model=model, tracked_module_names=tracked_modules_names) module_gradients = {} for module in model.modules(): @@ -191,6 +193,7 @@ def test_for_loop_per_sample_gradient_equivalence( "conv", "conv_bn", "bert", + "roberta", "gpt", ], ) @@ -230,8 +233,9 @@ def test_mean_gradient_equivalence( strategy="identity", ) if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True update_factor_args(model=model, factor_args=factor_args) + tracked_modules_names = get_tracked_module_names(model=model) per_sample_gradients = [] set_mode(model, ModuleMode.PRECONDITION_GRADIENT) @@ -247,7 +251,7 @@ def test_mean_gradient_equivalence( loss.backward() if test_name == "repeated_mlp": - finalize_preconditioned_gradient(model=model) + finalize_iteration(model=model, tracked_module_names=tracked_modules_names) module_gradients = {} for module in model.modules(): @@ -304,8 +308,9 @@ def test_mean_gradient_equivalence( "test_name", [ "mlp", + "repeated_mlp", "conv", - "gpt", + "roberta", ], ) @pytest.mark.parametrize("train_size", [32]) @@ -351,6 +356,8 @@ def test_lambda_equivalence( strategy="diagonal", use_empirical_fisher=True, ) + if test_name == "repeated_mlp": + factor_args.has_shared_parameters = True analyzer.fit_lambda_matrices( factors_name=f"pytest_{test_name}_lambda_diag", dataset=train_dataset, @@ -386,224 +393,3 @@ def test_lambda_equivalence( atol=ATOL, rtol=RTOL, ) - - -def test_precondition_gradient( - seed: int = 0, -) -> None: - input_dim = 128 - output_dim = 256 - batch_dim = 8 - lambda_scale = 1000 - damping = 1e-08 - - set_seed(seed) - A = torch.rand(size=(input_dim, input_dim), dtype=torch.float64) - B = torch.rand(size=(output_dim, output_dim), dtype=torch.float64) - Lambda = torch.rand(size=(output_dim, input_dim), dtype=torch.float64) - gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) - - start_time = time.time() - rotated_gradient = torch.einsum( - "ij,bjl,lk->bik", - ( - B.t(), - gradient, - A, - ), - ) - rotated_gradient.div_(Lambda + damping) - results = lambda_scale * torch.einsum( - "ij,bjl,lk->bik", - (B, rotated_gradient, A.t()), - ) - print(f"Took {time.time() - start_time} seconds.") - - start_time = time.time() - grads_rot = torch.matmul( - B.t(), - torch.matmul( - gradient, - A, - ), - ) - scaled_lambda = Lambda / lambda_scale - grads_rot.div_(scaled_lambda) - raw_results = torch.matmul( - B, - torch.matmul( - grads_rot, - A.t(), - ), - ) - print(f"Took {time.time() - start_time} seconds.") - - assert torch.allclose(raw_results, results, atol=1e-5, rtol=1e-3) - - -def test_query_gradient_svd( - seed: int = 0, -) -> None: - input_dim = 2048 - output_dim = 1024 - batch_dim = 16 - set_seed(seed) - - gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) - - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - assert torch.allclose(gradient, U @ torch.diag_embed(S) @ V, atol=1e-5, rtol=1e-3) - - rank = 32 - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - assert torch.bmm(left, right).shape == gradient.shape - - rank = 
input_dim - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - assert torch.allclose(torch.bmm(left, right), gradient, atol=1e-5, rtol=1e-3) - - rank = 32 - lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank), dtype=torch.float64) - lr_gradient2 = torch.rand(size=(batch_dim, rank, input_dim), dtype=torch.float64) - gradient = torch.bmm(lr_gradient1, lr_gradient2) - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - assert torch.allclose(torch.bmm(left_mat, right_mat), gradient, atol=1e-5, rtol=1e-3) - - query_batch_dim = 32 - new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) - score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) - - lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - assert torch.allclose(score, lr_score) - - lr_score_reconst_matmul = torch.matmul( - torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() - ) - assert torch.allclose(score, lr_score_reconst_matmul) - - # These should be able to avoid explicit reconstruction. - # This should be used when input_dim > output_dim. - intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) - final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) - assert torch.allclose(score, final) - print("Option 1") - print(intermediate.numel()) - - # This should be used when output_dim > input_dim. 
- intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) - final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) - assert torch.allclose(score, final2) - print("Option 2") - print(intermediate2.numel()) - - print("Reconstruction") - print((torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel()) - path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - print(path) - - -@pytest.mark.parametrize("input_dim", [256, 512]) -@pytest.mark.parametrize("output_dim", [512, 1024]) -@pytest.mark.parametrize("batch_dim", [8, 16]) -@pytest.mark.parametrize("qbatch_dim", [8, 16]) -@pytest.mark.parametrize("rank", [32]) -@pytest.mark.parametrize("seed", [0]) -def test_query_gradient_svd_reconst( - input_dim: int, - output_dim: int, - batch_dim: int, - qbatch_dim: int, - rank: int, - seed: int, -) -> None: - set_seed(seed) - - lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank + 50), dtype=torch.float64) - lr_gradient2 = torch.rand(size=(batch_dim, rank + 50, input_dim), dtype=torch.float64) - gradient = torch.bmm(lr_gradient1, lr_gradient2) - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - new_gradient = torch.rand(size=(qbatch_dim, output_dim, input_dim), dtype=torch.float64) - - lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - lr_score_reconst_matmul = torch.matmul( - torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() - ) - assert torch.allclose(lr_score, lr_score_reconst_matmul) - - # This should be used when input_dim > output_dim. - intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) - final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) - assert torch.allclose(lr_score, final) - print("Option 1") - print(intermediate.numel()) - - # This should be used when output_dim > input_dim. 
- intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) - final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) - assert torch.allclose(lr_score, final2) - print("Option 2") - print(intermediate2.numel()) - - print("Reconstruction") - reconst_numel = (torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel() - print(reconst_numel) - path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - print(path) - - if new_gradient.size(0) * right_mat.size(0) * rank * min((right_mat.size(2), left_mat.size(1))) > right_mat.size( - 0 - ) * right_mat.size(2) * left_mat.size(1): - assert intermediate2.numel() > reconst_numel and intermediate.numel() > reconst_numel - elif output_dim >= input_dim: - assert intermediate2.numel() <= reconst_numel - else: - assert intermediate.numel() <= reconst_numel - - -def test_compute_score_matmul( - seed: int = 0, -) -> None: - input_dim = 1024 - output_dim = 2048 - batch_dim = 16 - query_batch_dim = 64 - set_seed(seed) - - gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) - new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) - - score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) - path = opt_einsum.contract_path("toi,qoi->tq", gradient, new_gradient) - print(path) - - unsqueeze_score = opt_einsum.contract("t...,q...->tq", gradient, new_gradient) - assert torch.allclose(score, unsqueeze_score) - path = opt_einsum.contract_path("t...,q...->tq", gradient, new_gradient) - print(path) diff --git a/tests/scores/test_pairwise_scores.py b/tests/scores/test_pairwise_scores.py index 048309d..087627f 100644 --- a/tests/scores/test_pairwise_scores.py +++ b/tests/scores/test_pairwise_scores.py @@ -4,16 +4,20 @@ import pytest import torch +from scipy.stats import spearmanr -from kronfluence.arguments import FactorArguments, ScoreArguments -from kronfluence.utils.common.factor_arguments import test_factor_arguments -from kronfluence.utils.common.score_arguments import test_score_arguments +from kronfluence.arguments import ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ALL_MODULE_NAME from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ( ATOL, + DEFAULT_FACTORS_NAME, + DEFAULT_SCORES_NAME, RTOL, check_tensor_dict_equivalence, + custom_scores_name, prepare_model_and_analyzer, prepare_test, ) @@ -24,27 +28,27 @@ [ "mlp", "repeated_mlp", - "mlp_checkpoint", "conv", - "conv_bn", "bert", + "roberta", "gpt", + "gpt_checkpoint", ], ) -@pytest.mark.parametrize("score_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("query_gradient_rank", [None, 16]) +@pytest.mark.parametrize("score_dtype", [torch.float32]) +@pytest.mark.parametrize("query_gradient_low_rank", [None, 16]) @pytest.mark.parametrize("query_size", [16]) @pytest.mark.parametrize("train_size", [32]) @pytest.mark.parametrize("seed", [0]) def test_compute_pairwise_scores( test_name: str, score_dtype: torch.dtype, - query_gradient_rank: Optional[int], + query_gradient_low_rank: Optional[int], query_size: int, train_size: int, seed: int, ) -> None: - # Makes sure that the pairwise influence computations are working properly. + # Makes sure that pairwise influence computations are working properly. 
model, train_dataset, test_dataset, data_collator, task = prepare_test( test_name=test_name, query_size=query_size, @@ -56,13 +60,12 @@ def test_compute_pairwise_scores( model=model, task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True - factors_name = f"pytest_{test_name}_{test_compute_pairwise_scores.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, factor_args=factor_args, dataset=train_dataset, dataloader_kwargs=kwargs, @@ -72,12 +75,11 @@ def test_compute_pairwise_scores( score_args = ScoreArguments( score_dtype=score_dtype, - query_gradient_rank=query_gradient_rank, + query_gradient_low_rank=query_gradient_low_rank, ) - scores_name = f"pytest_{test_name}_{test_compute_pairwise_scores.__name__}_{query_gradient_rank}_scores" analyzer.compute_pairwise_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=4, train_dataset=train_dataset, @@ -87,38 +89,35 @@ def test_compute_pairwise_scores( overwrite_output_dir=True, ) - pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) + pairwise_scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME) assert pairwise_scores[ALL_MODULE_NAME].size(0) == query_size assert pairwise_scores[ALL_MODULE_NAME].size(1) == train_size assert pairwise_scores[ALL_MODULE_NAME].dtype == score_dtype -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - ], -) -@pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("score_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("query_gradient_rank", [None, 16]) -@pytest.mark.parametrize("damping", [None, 1e-08]) +@pytest.mark.parametrize("test_name", ["mlp"]) +@pytest.mark.parametrize("has_shared_parameters", [True, False]) +@pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("score_dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("query_gradient_low_rank", [None, 16]) +@pytest.mark.parametrize("damping_factor", [None, 1e-08]) @pytest.mark.parametrize("query_size", [16]) @pytest.mark.parametrize("train_size", [32]) -@pytest.mark.parametrize("seed", [6]) +@pytest.mark.parametrize("seed", [1]) def test_compute_pairwise_scores_dtype( test_name: str, + has_shared_parameters: bool, per_sample_gradient_dtype: torch.dtype, precondition_dtype: torch.dtype, score_dtype: torch.dtype, - query_gradient_rank: Optional[int], - damping: Optional[float], + query_gradient_low_rank: Optional[int], + damping_factor: Optional[float], query_size: int, train_size: int, seed: int, ) -> None: - # Makes sure that the pairwise influence computations are working properly with different dtypes. + # Makes sure that pairwise influence computations are working properly with different data types. 
model, train_dataset, test_dataset, data_collator, task = prepare_test( test_name=test_name, query_size=query_size, @@ -130,26 +129,28 @@ def test_compute_pairwise_scores_dtype( model=model, task=task, ) - factors_name = f"pytest_{test_name}_{test_compute_pairwise_scores_dtype.__name__}" + + factor_args = pytest_factor_arguments() + factor_args.has_shared_parameters = has_shared_parameters analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, + factor_args=factor_args, dataloader_kwargs=kwargs, per_device_batch_size=32, overwrite_output_dir=True, ) score_args = ScoreArguments( - damping=damping, + damping_factor=damping_factor, score_dtype=score_dtype, - query_gradient_rank=query_gradient_rank, + query_gradient_low_rank=query_gradient_low_rank, per_sample_gradient_dtype=per_sample_gradient_dtype, precondition_dtype=precondition_dtype, ) - scores_name = f"pytest_{test_name}_{test_compute_pairwise_scores_dtype.__name__}_{query_gradient_rank}_scores" analyzer.compute_pairwise_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=4, train_dataset=train_dataset, @@ -159,7 +160,7 @@ def test_compute_pairwise_scores_dtype( overwrite_output_dir=True, ) - pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) + pairwise_scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME) assert pairwise_scores[ALL_MODULE_NAME].size(0) == query_size assert pairwise_scores[ALL_MODULE_NAME].size(1) == train_size assert pairwise_scores[ALL_MODULE_NAME].dtype == score_dtype @@ -169,14 +170,13 @@ def test_compute_pairwise_scores_dtype( "test_name", [ "mlp", - "conv", - "gpt", + "conv_bn", ], ) @pytest.mark.parametrize("strategy", ["identity", "diagonal", "kfac", "ekfac"]) @pytest.mark.parametrize("query_size", [20]) @pytest.mark.parametrize("train_size", [50]) -@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("seed", [2]) def test_pairwise_scores_batch_size_equivalence( test_name: str, strategy: str, @@ -198,12 +198,9 @@ def test_pairwise_scores_batch_size_equivalence( task=task, ) - factor_args = FactorArguments( - strategy=strategy, - ) - factors_name = f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}" + factor_args = pytest_factor_arguments(strategy=strategy) analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=4, @@ -211,10 +208,10 @@ def test_pairwise_scores_batch_size_equivalence( overwrite_output_dir=True, ) - score_args = test_score_arguments() + score_args = pytest_score_arguments() analyzer.compute_pairwise_scores( - scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_bs1", - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=4, train_dataset=train_dataset, @@ -224,12 +221,12 @@ def test_pairwise_scores_batch_size_equivalence( overwrite_output_dir=True, ) bs1_scores = analyzer.load_pairwise_scores( - scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_bs1", + scores_name=DEFAULT_SCORES_NAME, ) analyzer.compute_pairwise_scores( - 
scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_bs8", - factors_name=factors_name, + scores_name=custom_scores_name("bs8"), + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=3, train_dataset=train_dataset, @@ -239,7 +236,7 @@ def test_pairwise_scores_batch_size_equivalence( overwrite_output_dir=True, ) bs8_scores = analyzer.load_pairwise_scores( - scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_bs8", + scores_name=custom_scores_name("bs8"), ) assert check_tensor_dict_equivalence( @@ -250,8 +247,8 @@ def test_pairwise_scores_batch_size_equivalence( ) analyzer.compute_pairwise_scores( - scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_auto", - factors_name=factors_name, + scores_name=custom_scores_name("auto"), + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=10, train_dataset=train_dataset, @@ -261,7 +258,7 @@ def test_pairwise_scores_batch_size_equivalence( overwrite_output_dir=True, ) bs_auto_scores = analyzer.load_pairwise_scores( - scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_auto", + scores_name=custom_scores_name("auto"), ) assert check_tensor_dict_equivalence( @@ -277,19 +274,22 @@ def test_pairwise_scores_batch_size_equivalence( [ "mlp", "conv", + "gpt", ], ) -@pytest.mark.parametrize("data_partition_size", [1, 4]) -@pytest.mark.parametrize("module_partition_size", [1, 3]) -@pytest.mark.parametrize("per_module_score", [True, False]) +@pytest.mark.parametrize("data_partitions", [2, 4]) +@pytest.mark.parametrize("module_partitions", [2, 3]) +@pytest.mark.parametrize("compute_per_module_scores", [True, False]) +@pytest.mark.parametrize("compute_per_token_scores", [True, False]) @pytest.mark.parametrize("query_size", [32]) @pytest.mark.parametrize("train_size", [64]) -@pytest.mark.parametrize("seed", [2]) +@pytest.mark.parametrize("seed", [3]) def test_pairwise_scores_partition_equivalence( test_name: str, - data_partition_size: int, - module_partition_size: int, - per_module_score: bool, + data_partitions: int, + module_partitions: int, + compute_per_module_scores: bool, + compute_per_token_scores: bool, query_size: int, train_size: int, seed: int, @@ -308,21 +308,20 @@ def test_pairwise_scores_partition_equivalence( task=task, ) - factors_name = f"pytest_{test_name}_{test_pairwise_scores_partition_equivalence.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=8, overwrite_output_dir=True, ) - scores_name = f"pytest_{test_name}_{test_pairwise_scores_partition_equivalence.__name__}_scores" - score_args = test_score_arguments() - score_args.per_module_score = per_module_score + score_args = pytest_score_arguments() + score_args.compute_per_module_scores = compute_per_module_scores + score_args.compute_per_token_scores = compute_per_token_scores analyzer.compute_pairwise_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=4, train_dataset=train_dataset, @@ -331,13 +330,13 @@ def test_pairwise_scores_partition_equivalence( score_args=score_args, overwrite_output_dir=True, ) - scores = 
analyzer.load_pairwise_scores(scores_name=scores_name) + scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME) - score_args.data_partition_size = data_partition_size - score_args.module_partition_size = module_partition_size + score_args.data_partitions = data_partitions + score_args.module_partitions = module_partitions analyzer.compute_pairwise_scores( - scores_name=f"pytest_{test_name}_partition_{data_partition_size}_{module_partition_size}", - factors_name=factors_name, + scores_name=custom_scores_name(f"{data_partitions}_{module_partitions}"), + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=10, train_dataset=train_dataset, @@ -347,7 +346,7 @@ def test_pairwise_scores_partition_equivalence( overwrite_output_dir=True, ) partitioned_scores = analyzer.load_pairwise_scores( - scores_name=f"pytest_{test_name}_partition_{data_partition_size}_{module_partition_size}", + scores_name=custom_scores_name(f"{data_partitions}_{module_partitions}"), ) assert check_tensor_dict_equivalence( @@ -368,7 +367,7 @@ def test_pairwise_scores_partition_equivalence( ) @pytest.mark.parametrize("query_size", [32]) @pytest.mark.parametrize("train_size", [64]) -@pytest.mark.parametrize("seed", [0]) +@pytest.mark.parametrize("seed", [4]) def test_per_module_scores_equivalence( test_name: str, query_size: int, @@ -389,20 +388,18 @@ def test_per_module_scores_equivalence( task=task, ) - factors_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=8, overwrite_output_dir=True, ) - scores_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" - score_args = test_score_arguments() + score_args = pytest_score_arguments() analyzer.compute_pairwise_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=4, train_dataset=train_dataset, @@ -411,12 +408,12 @@ def test_per_module_scores_equivalence( score_args=score_args, overwrite_output_dir=True, ) - scores = analyzer.load_pairwise_scores(scores_name=scores_name) + scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME) - score_args.per_module_score = True + score_args.compute_per_module_scores = True analyzer.compute_pairwise_scores( - scores_name=scores_name + "_per_module", - factors_name=factors_name, + scores_name=custom_scores_name("per_module"), + factors_name=DEFAULT_FACTORS_NAME, query_dataset=test_dataset, per_device_query_batch_size=4, train_dataset=train_dataset, @@ -425,7 +422,7 @@ def test_per_module_scores_equivalence( score_args=score_args, overwrite_output_dir=True, ) - per_module_scores = analyzer.load_pairwise_scores(scores_name=scores_name + "_per_module") + per_module_scores = analyzer.load_pairwise_scores(scores_name=custom_scores_name("per_module")) total_scores = None for module_name in per_module_scores: @@ -437,6 +434,76 @@ def test_per_module_scores_equivalence( assert torch.allclose(total_scores, scores[ALL_MODULE_NAME], atol=ATOL, rtol=RTOL) +@pytest.mark.parametrize("test_name", ["mlp", "conv", "gpt"]) +@pytest.mark.parametrize("compute_per_module_scores", [True, False]) +@pytest.mark.parametrize("query_size", [12]) +@pytest.mark.parametrize("train_size", [64]) +@pytest.mark.parametrize("seed", [5]) +def 
+    test_name: str,
+    compute_per_module_scores: bool,
+    query_size: int,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Influence scores should be identical with and without per-token score computations.
+    model, train_dataset, test_dataset, data_collator, task = prepare_test(
+        test_name=test_name,
+        query_size=query_size,
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        dataloader_kwargs=kwargs,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+    )
+
+    score_args = pytest_score_arguments()
+    score_args.compute_per_module_scores = compute_per_module_scores
+    analyzer.compute_pairwise_scores(
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
+        query_dataset=test_dataset,
+        per_device_query_batch_size=4,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        score_args=score_args,
+        overwrite_output_dir=True,
+    )
+    scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
+
+    score_args.compute_per_token_scores = True
+    analyzer.compute_pairwise_scores(
+        scores_name=custom_scores_name("per_token"),
+        factors_name=DEFAULT_FACTORS_NAME,
+        query_dataset=test_dataset,
+        per_device_query_batch_size=4,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        score_args=score_args,
+        overwrite_output_dir=True,
+    )
+    per_token_scores = analyzer.load_pairwise_scores(scores_name=custom_scores_name("per_token"))
+
+    for module_name in per_token_scores:
+        if test_name == "gpt":
+            assert torch.allclose(per_token_scores[module_name].sum(dim=-1), scores[module_name], atol=ATOL, rtol=RTOL)
+        else:
+            assert torch.allclose(per_token_scores[module_name], scores[module_name], atol=ATOL, rtol=RTOL)
+
+
 @pytest.mark.parametrize(
     "test_name",
     [
@@ -444,11 +511,17 @@ def test_per_module_scores_equivalence(
         "conv_bn",
     ],
 )
+@pytest.mark.parametrize("data_partitions", [1, 2])
+@pytest.mark.parametrize("query_gradient_low_rank", [None, 4])
+@pytest.mark.parametrize("query_gradient_accumulation_steps", [1, 4])
 @pytest.mark.parametrize("query_size", [60])
 @pytest.mark.parametrize("train_size", [60])
-@pytest.mark.parametrize("seed", [2])
+@pytest.mark.parametrize("seed", [6])
 def test_compute_pairwise_scores_with_indices(
     test_name: str,
+    data_partitions: int,
+    query_gradient_low_rank: Optional[int],
+    query_gradient_accumulation_steps: int,
     query_size: int,
     train_size: int,
     seed: int,
@@ -465,21 +538,21 @@ def test_compute_pairwise_scores_with_indices(
         model=model,
         task=task,
     )
-    factors_name = f"pytest_{test_name}_{test_compute_pairwise_scores_with_indices.__name__}"
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=32,
         overwrite_output_dir=True,
     )
 
-    score_args = test_score_arguments()
-    score_args.data_partition_size = 2
-    scores_name = f"pytest_{test_name}_{test_compute_pairwise_scores_with_indices.__name__}_scores"
+    score_args = pytest_score_arguments()
+    score_args.data_partitions = data_partitions
+    score_args.query_gradient_low_rank = query_gradient_low_rank
+    score_args.query_gradient_accumulation_steps = query_gradient_accumulation_steps
     analyzer.compute_pairwise_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         query_dataset=test_dataset,
         query_indices=list(range(30)),
         per_device_query_batch_size=4,
@@ -491,7 +564,7 @@ def test_compute_pairwise_scores_with_indices(
         overwrite_output_dir=True,
     )
 
-    pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name)
+    pairwise_scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
     assert pairwise_scores[ALL_MODULE_NAME].size(0) == 30
     assert pairwise_scores[ALL_MODULE_NAME].size(1) == 50
 
@@ -500,18 +573,20 @@ def test_compute_pairwise_scores_with_indices(
     "test_name",
     [
         "mlp",
-        "conv",
+        "conv_bn",
     ],
 )
 @pytest.mark.parametrize("query_size", [64])
 @pytest.mark.parametrize("train_size", [32])
-@pytest.mark.parametrize("num_query_gradient_accumulations", [2, 5])
-@pytest.mark.parametrize("seed", [5])
-def test_query_accumulation(
+@pytest.mark.parametrize("query_gradient_low_rank", [None])
+@pytest.mark.parametrize("query_gradient_accumulation_steps", [2, 5])
+@pytest.mark.parametrize("seed", [7])
+def test_query_accumulation_steps(
     test_name: str,
     query_size: int,
     train_size: int,
-    num_query_gradient_accumulations: int,
+    query_gradient_low_rank: Optional[int],
+    query_gradient_accumulation_steps: int,
     seed: int,
 ) -> None:
     # Makes sure the query accumulation is correctly implemented.
@@ -528,20 +603,18 @@ def test_query_accumulation_steps(
         task=task,
     )
 
-    factors_name = f"pytest_{test_name}_{test_query_accumulation.__name__}"
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=8,
         overwrite_output_dir=True,
     )
 
-    scores_name = f"pytest_{test_name}_{test_query_accumulation.__name__}_scores"
-    score_args = test_score_arguments(query_gradient_rank=8)
+    score_args = pytest_score_arguments(query_gradient_low_rank=query_gradient_low_rank)
     analyzer.compute_pairwise_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         query_dataset=test_dataset,
         per_device_query_batch_size=4,
         train_dataset=train_dataset,
@@ -550,12 +623,12 @@ def test_query_accumulation_steps(
         score_args=score_args,
         overwrite_output_dir=True,
     )
-    scores = analyzer.load_pairwise_scores(scores_name=scores_name)
+    scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
 
-    score_args.num_query_gradient_accumulations = num_query_gradient_accumulations
+    score_args.query_gradient_accumulation_steps = query_gradient_accumulation_steps
     analyzer.compute_pairwise_scores(
-        scores_name=f"pytest_{test_name}_{test_query_accumulation.__name__}_accumulated_scores",
-        factors_name=factors_name,
+        scores_name=custom_scores_name("accumulation"),
+        factors_name=DEFAULT_FACTORS_NAME,
         query_dataset=test_dataset,
         per_device_query_batch_size=10,
         train_dataset=train_dataset,
@@ -564,13 +637,184 @@ def test_query_accumulation_steps(
         score_args=score_args,
         overwrite_output_dir=True,
     )
-    partitioned_scores = analyzer.load_pairwise_scores(
-        scores_name=f"pytest_{test_name}_{test_query_accumulation.__name__}_accumulated_scores",
+    accumulated_scores = analyzer.load_pairwise_scores(
+        scores_name=custom_scores_name("accumulation"),
     )
 
     assert check_tensor_dict_equivalence(
         scores,
-        partitioned_scores,
+        accumulated_scores,
+        atol=ATOL,
+        rtol=RTOL,
+    )
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    ["mlp", "conv"],
+)
+@pytest.mark.parametrize("query_size", [50])
+@pytest.mark.parametrize("train_size", [32]) +@pytest.mark.parametrize("data_partitions", [3]) +@pytest.mark.parametrize("module_partitions", [3]) +@pytest.mark.parametrize("query_gradient_low_rank", [None]) +@pytest.mark.parametrize("seed", [8]) +def test_query_gradient_aggregation( + test_name: str, + query_size: int, + train_size: int, + data_partitions: int, + module_partitions: int, + query_gradient_low_rank: Optional[int], + seed: int, +) -> None: + # Makes sure the query gradient aggregation is correctly implemented. + model, train_dataset, test_dataset, data_collator, task = prepare_test( + test_name=test_name, + query_size=query_size, + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + kwargs = DataLoaderKwargs(collate_fn=data_collator) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + + factor_args = pytest_factor_arguments() + if test_name == "repeated_mlp": + factor_args.has_shared_parameters = True + analyzer.fit_all_factors( + factors_name=DEFAULT_FACTORS_NAME, + dataset=train_dataset, + factor_args=factor_args, + dataloader_kwargs=kwargs, + per_device_batch_size=8, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments(query_gradient_low_rank=query_gradient_low_rank) + analyzer.compute_pairwise_scores( + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, + query_dataset=test_dataset, + per_device_query_batch_size=4, + train_dataset=train_dataset, + per_device_train_batch_size=8, + dataloader_kwargs=kwargs, + score_args=score_args, + overwrite_output_dir=True, + ) + scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME) + + score_args.aggregate_query_gradients = True + score_args.data_partitions = data_partitions + score_args.module_partitions = data_partitions + analyzer.compute_pairwise_scores( + scores_name=custom_scores_name("aggregation"), + factors_name=DEFAULT_FACTORS_NAME, + query_dataset=test_dataset, + per_device_query_batch_size=10, + train_dataset=train_dataset, + per_device_train_batch_size=5, + dataloader_kwargs=kwargs, + score_args=score_args, + overwrite_output_dir=True, + ) + aggregated_scores = analyzer.load_pairwise_scores( + scores_name=custom_scores_name("aggregation"), + ) + + assert aggregated_scores[ALL_MODULE_NAME].shape[0] == 1 + assert torch.allclose( + scores[ALL_MODULE_NAME].sum(dim=0, keepdim=True), + aggregated_scores[ALL_MODULE_NAME], + atol=ATOL, + rtol=RTOL, + ) + + +@pytest.mark.parametrize( + "test_name", + ["mlp", "conv"], +) +@pytest.mark.parametrize("query_size", [64]) +@pytest.mark.parametrize("train_size", [32]) +@pytest.mark.parametrize("data_partitions", [3]) +@pytest.mark.parametrize("module_partitions", [2]) +@pytest.mark.parametrize("aggregate_query_gradients", [True, False]) +@pytest.mark.parametrize("query_gradient_low_rank", [None]) +@pytest.mark.parametrize("seed", [9]) +def test_train_gradient_aggregation( + test_name: str, + query_size: int, + train_size: int, + data_partitions: int, + module_partitions: int, + aggregate_query_gradients: bool, + query_gradient_low_rank: Optional[int], + seed: int, +) -> None: + # Makes sure the train gradient aggregation is correctly implemented. 
+    model, train_dataset, test_dataset, data_collator, task = prepare_test(
+        test_name=test_name,
+        query_size=query_size,
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        dataloader_kwargs=kwargs,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+    )
+
+    score_args = pytest_score_arguments(query_gradient_low_rank=query_gradient_low_rank)
+    score_args.aggregate_query_gradients = aggregate_query_gradients
+    analyzer.compute_pairwise_scores(
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
+        query_dataset=test_dataset,
+        per_device_query_batch_size=4,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        score_args=score_args,
+        overwrite_output_dir=True,
+    )
+    scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
+
+    score_args.aggregate_train_gradients = True
+    score_args.data_partitions = data_partitions
+    score_args.module_partitions = module_partitions
+    analyzer.compute_pairwise_scores(
+        scores_name=custom_scores_name("aggregation"),
+        factors_name=DEFAULT_FACTORS_NAME,
+        query_dataset=test_dataset,
+        per_device_query_batch_size=10,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=5,
+        dataloader_kwargs=kwargs,
+        score_args=score_args,
+        overwrite_output_dir=True,
+    )
+    aggregated_scores = analyzer.load_pairwise_scores(
+        scores_name=custom_scores_name("aggregation"),
+    )
+
+    assert aggregated_scores[ALL_MODULE_NAME].shape[1] == 1
+    assert torch.allclose(
+        scores[ALL_MODULE_NAME].sum(dim=1, keepdim=True),
+        aggregated_scores[ALL_MODULE_NAME],
         atol=ATOL,
         rtol=RTOL,
     )
@@ -581,18 +825,19 @@ def test_query_accumulation(
     [
         "mlp",
         "conv",
+        "roberta",
     ],
 )
 @pytest.mark.parametrize("query_size", [50])
 @pytest.mark.parametrize("train_size", [32])
-@pytest.mark.parametrize("seed", [8])
+@pytest.mark.parametrize("seed", [10])
 def test_pairwise_shared_parameters(
     test_name: str,
     query_size: int,
     train_size: int,
     seed: int,
 ) -> None:
-    # Makes sure the scores are identical with and without `shared_parameters_exist` flag.
+    # Makes sure the scores are identical with and without `has_shared_parameters` flag.
     model, train_dataset, test_dataset, data_collator, task = prepare_test(
         test_name=test_name,
         query_size=query_size,
@@ -605,52 +850,50 @@ def test_pairwise_shared_parameters(
         model=model,
         task=task,
     )
-    factor_args = test_factor_arguments()
-    factor_args.shared_parameters_exist = False
-    factors_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}"
+    factor_args = pytest_factor_arguments()
+    score_args = pytest_score_arguments()
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         factor_args=factor_args,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=8,
         overwrite_output_dir=True,
     )
-    scores_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}_scores"
     analyzer.compute_pairwise_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         query_dataset=test_dataset,
         per_device_query_batch_size=4,
         train_dataset=train_dataset,
+        score_args=score_args,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
         overwrite_output_dir=True,
     )
-    scores = analyzer.load_pairwise_scores(scores_name=scores_name)
+    scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
 
-    factor_args.shared_parameters_exist = True
-    factors_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}_shared"
+    factor_args.has_shared_parameters = True
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         factor_args=factor_args,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=8,
         overwrite_output_dir=True,
     )
-    scores_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}_shared_scores"
     analyzer.compute_pairwise_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=custom_scores_name("shared"),
+        factors_name=DEFAULT_FACTORS_NAME,
         query_dataset=test_dataset,
         per_device_query_batch_size=4,
         train_dataset=train_dataset,
+        score_args=score_args,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
         overwrite_output_dir=True,
     )
-    shared_scores = analyzer.load_pairwise_scores(scores_name=scores_name)
+    shared_scores = analyzer.load_pairwise_scores(scores_name=custom_scores_name("shared"))
 
     assert check_tensor_dict_equivalence(
         scores,
@@ -658,3 +901,78 @@ def test_pairwise_shared_parameters(
         atol=ATOL,
         rtol=RTOL,
     )
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    ["mlp", "conv_bn", "gpt"],
+)
+@pytest.mark.parametrize("query_gradient_low_rank", [16, 32])
+@pytest.mark.parametrize("use_full_svd", [False, True])
+@pytest.mark.parametrize("query_gradient_accumulation_steps", [1, 3])
+@pytest.mark.parametrize("query_size", [64])
+@pytest.mark.parametrize("train_size", [160])
+@pytest.mark.parametrize("seed", [11])
+def test_pairwise_query_batching(
+    test_name: str,
+    query_gradient_low_rank: int,
+    use_full_svd: bool,
+    query_gradient_accumulation_steps: int,
+    query_size: int,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Makes sure similar results are obtained with and without query batching.
+    model, train_dataset, test_dataset, data_collator, task = prepare_test(
+        test_name=test_name,
+        query_size=query_size,
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+    factor_args = pytest_factor_arguments()
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        factor_args=factor_args,
+        dataset=train_dataset,
+        dataloader_kwargs=kwargs,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+    )
+    score_args = pytest_score_arguments()
+    analyzer.compute_pairwise_scores(
+        scores_name=DEFAULT_SCORES_NAME,
+        score_args=score_args,
+        factors_name=DEFAULT_FACTORS_NAME,
+        query_dataset=test_dataset,
+        per_device_query_batch_size=4,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+    scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
+
+    score_args.query_gradient_low_rank = query_gradient_low_rank
+    score_args.use_full_svd = use_full_svd
+    score_args.query_gradient_accumulation_steps = query_gradient_accumulation_steps
+    analyzer.compute_pairwise_scores(
+        scores_name=custom_scores_name("qb"),
+        score_args=score_args,
+        factors_name=DEFAULT_FACTORS_NAME,
+        query_dataset=test_dataset,
+        per_device_query_batch_size=3,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=9,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+    qb_scores = analyzer.load_pairwise_scores(scores_name=custom_scores_name("qb"))
+
+    for i in range(query_size):
+        assert spearmanr(scores[ALL_MODULE_NAME][i], qb_scores[ALL_MODULE_NAME][i])[0] > 0.9
diff --git a/tests/scores/test_self_scores.py b/tests/scores/test_self_scores.py
index e94ebda..c0ee190 100644
--- a/tests/scores/test_self_scores.py
+++ b/tests/scores/test_self_scores.py
@@ -1,20 +1,26 @@
 # pylint: skip-file
 
+from typing import Optional
+
 import pytest
 import torch
 
 from kronfluence.arguments import ScoreArguments
 from kronfluence.utils.common.factor_arguments import (
     default_factor_arguments,
-    test_factor_arguments,
+    pytest_factor_arguments,
 )
-from kronfluence.utils.common.score_arguments import test_score_arguments
+from kronfluence.utils.common.score_arguments import pytest_score_arguments
 from kronfluence.utils.constants import ALL_MODULE_NAME
 from kronfluence.utils.dataset import DataLoaderKwargs
 from tests.utils import (
     ATOL,
+    DEFAULT_FACTORS_NAME,
+    DEFAULT_SCORES_NAME,
     RTOL,
     check_tensor_dict_equivalence,
+    custom_factors_name,
+    custom_scores_name,
     prepare_model_and_analyzer,
     prepare_test,
 )
@@ -25,11 +31,11 @@
     [
         "mlp",
         "repeated_mlp",
-        "mlp_checkpoint",
         "conv",
-        "conv_bn",
         "bert",
+        "roberta",
         "gpt",
+        "gpt_checkpoint",
     ],
 )
 @pytest.mark.parametrize("use_measurement_for_self_influence", [False, True])
@@ -43,7 +49,7 @@ def test_compute_self_scores(
     train_size: int,
     seed: int,
 ) -> None:
-    # Makes sure that the self-influence computations are working properly.
+    # Makes sure that self-influence computations are working properly.
     model, train_dataset, _, data_collator, task = prepare_test(
         test_name=test_name,
         train_size=train_size,
@@ -56,11 +62,10 @@ def test_compute_self_scores(
     )
     factor_args = default_factor_arguments()
     if test_name == "repeated_mlp":
-        factor_args.shared_parameters_exist = True
+        factor_args.has_shared_parameters = True
 
-    factors_name = f"pytest_{test_name}_{test_compute_self_scores.__name__}"
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         factor_args=factor_args,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
@@ -72,10 +77,9 @@ def test_compute_self_scores(
         use_measurement_for_self_influence=use_measurement_for_self_influence,
         score_dtype=score_dtype,
     )
-    scores_name = f"pytest_{test_name}_{test_compute_self_scores.__name__}_scores"
     analyzer.compute_self_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=4,
         dataloader_kwargs=kwargs,
@@ -83,30 +87,31 @@ def test_compute_self_scores(
         overwrite_output_dir=True,
     )
 
-    self_scores = analyzer.load_self_scores(scores_name=scores_name)
+    self_scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME)
     assert self_scores[ALL_MODULE_NAME].size(0) == train_size
     assert len(self_scores[ALL_MODULE_NAME].shape) == 1
     assert self_scores[ALL_MODULE_NAME].dtype == score_dtype
 
 
-@pytest.mark.parametrize(
-    "test_name",
-    ["mlp"],
-)
-@pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.bfloat16])
-@pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.bfloat16])
-@pytest.mark.parametrize("score_dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("test_name", ["mlp"])
+@pytest.mark.parametrize("has_shared_parameters", [True, False])
+@pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("score_dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("damping_factor", [None, 1e-08])
 @pytest.mark.parametrize("train_size", [32])
-@pytest.mark.parametrize("seed", [6])
+@pytest.mark.parametrize("seed", [1])
 def test_compute_self_scores_dtype(
     test_name: str,
+    has_shared_parameters: bool,
     per_sample_gradient_dtype: torch.dtype,
     precondition_dtype: torch.dtype,
     score_dtype: torch.dtype,
+    damping_factor: Optional[float],
     train_size: int,
     seed: int,
 ) -> None:
-    # Make sure that the self-influence computations are working properly with different dtypes.
+    # Makes sure that self-influence computations are working properly with different data types.
     model, train_dataset, _, data_collator, task = prepare_test(
         test_name=test_name,
         query_size=10,
@@ -118,10 +123,13 @@ def test_compute_self_scores_dtype(
         model=model,
         task=task,
     )
-    factors_name = f"pytest_{test_name}_{test_compute_self_scores_dtype.__name__}"
+
+    factor_args = pytest_factor_arguments()
+    factor_args.has_shared_parameters = has_shared_parameters
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         dataset=train_dataset,
+        factor_args=factor_args,
         dataloader_kwargs=kwargs,
         per_device_batch_size=32,
         overwrite_output_dir=True,
@@ -131,18 +139,18 @@ def test_compute_self_scores_dtype(
         score_dtype=score_dtype,
         per_sample_gradient_dtype=per_sample_gradient_dtype,
         precondition_dtype=precondition_dtype,
+        damping_factor=damping_factor,
     )
-    scores_name = f"pytest_{test_name}_{test_compute_self_scores_dtype.__name__}_scores"
     analyzer.compute_self_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
         score_args=score_args,
         overwrite_output_dir=True,
     )
-    self_scores = analyzer.load_self_scores(scores_name=scores_name)
+    self_scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME)
     assert self_scores[ALL_MODULE_NAME].size(0) == train_size
     assert len(self_scores[ALL_MODULE_NAME].shape) == 1
     assert self_scores[ALL_MODULE_NAME].dtype == score_dtype
@@ -152,13 +160,12 @@
     "test_name",
     [
         "mlp",
-        "conv",
-        "gpt",
+        "conv_bn",
     ],
 )
 @pytest.mark.parametrize("strategy", ["identity", "diagonal", "kfac", "ekfac"])
-@pytest.mark.parametrize("train_size", [50])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("train_size", [49])
+@pytest.mark.parametrize("seed", [2])
 def test_self_scores_batch_size_equivalence(
     test_name: str,
     strategy: str,
@@ -178,10 +185,9 @@ def test_self_scores_batch_size_equivalence(
         task=task,
     )
 
-    factor_args = test_factor_arguments(strategy=strategy)
-    factors_name = f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}"
+    factor_args = pytest_factor_arguments(strategy=strategy)
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=4,
@@ -189,10 +195,10 @@ def test_self_scores_batch_size_equivalence(
         overwrite_output_dir=True,
     )
 
-    score_args = test_score_arguments()
+    score_args = pytest_score_arguments()
     analyzer.compute_self_scores(
-        scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_bs1",
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=1,
         dataloader_kwargs=kwargs,
@@ -200,12 +206,12 @@ def test_self_scores_batch_size_equivalence(
         overwrite_output_dir=True,
     )
     bs1_scores = analyzer.load_self_scores(
-        scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_bs1",
+        scores_name=DEFAULT_SCORES_NAME,
     )
 
     analyzer.compute_self_scores(
-        scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_bs8",
-        factors_name=factors_name,
+        scores_name=custom_scores_name("bs8"),
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
@@ -213,7 +219,7 @@ def test_self_scores_batch_size_equivalence(
         overwrite_output_dir=True,
     )
     bs8_scores = analyzer.load_self_scores(
-        scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_bs8",
+        scores_name=custom_scores_name("bs8"),
     )
 
     assert check_tensor_dict_equivalence(
@@ -224,8 +230,8 @@ def test_self_scores_batch_size_equivalence(
     )
 
     analyzer.compute_self_scores(
-        scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_auto",
-        factors_name=factors_name,
+        scores_name=custom_scores_name("auto"),
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=None,
         dataloader_kwargs=kwargs,
@@ -233,7 +239,7 @@ def test_self_scores_batch_size_equivalence(
         overwrite_output_dir=True,
     )
     bs_auto_scores = analyzer.load_self_scores(
-        scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_auto",
+        scores_name=custom_scores_name("auto"),
    )
 
     assert check_tensor_dict_equivalence(
@@ -252,14 +258,16 @@
         "gpt",
     ],
 )
-@pytest.mark.parametrize("data_partition_size", [1, 4])
-@pytest.mark.parametrize("module_partition_size", [1, 3])
+@pytest.mark.parametrize("data_partitions", [2, 4])
+@pytest.mark.parametrize("module_partitions", [2, 3])
+@pytest.mark.parametrize("compute_per_module_scores", [True, False])
 @pytest.mark.parametrize("train_size", [64])
-@pytest.mark.parametrize("seed", [2])
+@pytest.mark.parametrize("seed", [3])
 def test_self_scores_partition_equivalence(
     test_name: str,
-    data_partition_size: int,
-    module_partition_size: int,
+    data_partitions: int,
+    module_partitions: int,
+    compute_per_module_scores: bool,
     train_size: int,
     seed: int,
 ) -> None:
@@ -276,33 +284,32 @@ def test_self_scores_partition_equivalence(
         task=task,
     )
 
-    factors_name = f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}"
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=8,
         overwrite_output_dir=True,
     )
 
-    scores_name = f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}_scores"
-    score_args = test_score_arguments()
+    score_args = pytest_score_arguments()
+    score_args.compute_per_module_scores = compute_per_module_scores
     analyzer.compute_self_scores(
-        scores_name=scores_name,
-        factors_name=factors_name,
+        scores_name=DEFAULT_SCORES_NAME,
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
         score_args=score_args,
         overwrite_output_dir=True,
     )
-    scores = analyzer.load_self_scores(scores_name=scores_name)
+    scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME)
 
-    score_args.data_partition_size = data_partition_size
-    score_args.module_partition_size = module_partition_size
+    score_args.data_partitions = data_partitions
+    score_args.module_partitions = module_partitions
     analyzer.compute_self_scores(
-        scores_name=f"pytest_{test_name}_partition_{data_partition_size}_{module_partition_size}",
-        factors_name=factors_name,
+        scores_name=custom_scores_name(f"{data_partitions}_{module_partitions}"),
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=5,
         dataloader_kwargs=kwargs,
@@ -310,7 +317,7 @@ def test_self_scores_partition_equivalence(
         overwrite_output_dir=True,
     )
     partitioned_scores = analyzer.load_self_scores(
scores_name=f"pytest_{test_name}_partition_{data_partition_size}_{module_partition_size}", + scores_name=custom_scores_name(f"{data_partitions}_{module_partitions}"), ) assert check_tensor_dict_equivalence( @@ -325,7 +332,7 @@ def test_self_scores_partition_equivalence( "test_name", [ "mlp", - "conv", + "conv_bn", "gpt", ], ) @@ -349,39 +356,37 @@ def test_per_module_scores_equivalence( task=task, ) - factors_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=8, overwrite_output_dir=True, ) - scores_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" - score_args = test_score_arguments() + score_args = pytest_score_arguments() analyzer.compute_self_scores( - scores_name=scores_name, + scores_name=DEFAULT_SCORES_NAME, score_args=score_args, - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, train_dataset=train_dataset, per_device_train_batch_size=8, dataloader_kwargs=kwargs, overwrite_output_dir=True, ) - scores = analyzer.load_self_scores(scores_name=scores_name) + scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME) score_args.per_module_score = True analyzer.compute_self_scores( - scores_name=scores_name + "_per_module", - factors_name=factors_name, + scores_name=custom_scores_name("per_module"), + factors_name=DEFAULT_FACTORS_NAME, train_dataset=train_dataset, per_device_train_batch_size=8, dataloader_kwargs=kwargs, score_args=score_args, overwrite_output_dir=True, ) - per_module_scores = analyzer.load_self_scores(scores_name=scores_name + "_per_module") + per_module_scores = analyzer.load_self_scores(custom_scores_name("per_module")) total_scores = None for module_name in per_module_scores: @@ -398,13 +403,14 @@ def test_per_module_scores_equivalence( [ "mlp", "conv_bn", - "gpt", ], ) +@pytest.mark.parametrize("data_partitions", [1, 2]) @pytest.mark.parametrize("train_size", [60]) -@pytest.mark.parametrize("seed", [7]) +@pytest.mark.parametrize("seed", [6]) def test_compute_self_scores_with_indices( test_name: str, + data_partitions: int, train_size: int, seed: int, ) -> None: @@ -419,21 +425,19 @@ def test_compute_self_scores_with_indices( model=model, task=task, ) - factors_name = f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=32, overwrite_output_dir=True, ) - score_args = test_score_arguments() - score_args.data_partition_size = 2 - scores_name = f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}_scores" + score_args = pytest_score_arguments() + score_args.data_partitions = data_partitions analyzer.compute_self_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, train_dataset=train_dataset, train_indices=list(range(48)), per_device_train_batch_size=8, @@ -442,18 +446,13 @@ def test_compute_self_scores_with_indices( overwrite_output_dir=True, ) - self_scores = analyzer.load_self_scores(scores_name=scores_name) + self_scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME) assert self_scores[ALL_MODULE_NAME].size(0) == 48 -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - ], -) +@pytest.mark.parametrize("test_name", ["mlp"]) 
@pytest.mark.parametrize("train_size", [60]) -@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("seed", [7]) def test_compute_self_scores_with_diagonal_pairwise_equivalence( test_name: str, train_size: int, @@ -471,31 +470,29 @@ def test_compute_self_scores_with_diagonal_pairwise_equivalence( model=model, task=task, ) - factors_name = f"pytest_{test_name}_{test_compute_self_scores_with_diagonal_pairwise_equivalence.__name__}" analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=32, overwrite_output_dir=True, ) - scores_name = f"pytest_{test_name}_{test_compute_self_scores_with_diagonal_pairwise_equivalence.__name__}_scores" - score_args = test_score_arguments() + score_args = pytest_score_arguments() analyzer.compute_self_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, train_dataset=train_dataset, per_device_train_batch_size=8, dataloader_kwargs=kwargs, score_args=score_args, overwrite_output_dir=True, ) - self_scores = analyzer.load_self_scores(scores_name=scores_name) + self_scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME) analyzer.compute_pairwise_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, train_dataset=train_dataset, per_device_train_batch_size=8, query_dataset=train_dataset, @@ -504,7 +501,7 @@ def test_compute_self_scores_with_diagonal_pairwise_equivalence( score_args=score_args, overwrite_output_dir=True, ) - pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) + pairwise_scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME) assert torch.allclose( torch.diag(pairwise_scores[ALL_MODULE_NAME]), @@ -519,7 +516,7 @@ def test_compute_self_scores_with_diagonal_pairwise_equivalence( ["mlp", "conv", "conv_bn", "wrong_conv"], ) @pytest.mark.parametrize("train_size", [24]) -@pytest.mark.parametrize("seed", [7]) +@pytest.mark.parametrize("seed", [8]) def test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence( test_name: str, train_size: int, @@ -537,36 +534,30 @@ def test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence( model=model, task=task, ) - factors_name = ( - f"pytest_{test_name}_{test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence.__name__}" - ) analyzer.fit_all_factors( - factors_name=factors_name, + factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=32, overwrite_output_dir=True, ) - scores_name = ( - f"pytest_{test_name}_{test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence.__name__}_scores" - ) - score_args = test_score_arguments() + score_args = pytest_score_arguments() score_args.use_measurement_for_self_influence = True analyzer.compute_self_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + factors_name=DEFAULT_FACTORS_NAME, train_dataset=train_dataset, per_device_train_batch_size=8, dataloader_kwargs=kwargs, score_args=score_args, overwrite_output_dir=True, ) - self_scores = analyzer.load_self_scores(scores_name=scores_name) + self_scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME) analyzer.compute_pairwise_scores( - scores_name=scores_name, - factors_name=factors_name, + scores_name=DEFAULT_SCORES_NAME, + 
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         query_dataset=train_dataset,
@@ -575,7 +566,7 @@ def test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence(
         score_args=score_args,
         overwrite_output_dir=True,
     )
-    pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name)
+    pairwise_scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)
 
     assert torch.allclose(
         torch.diag(pairwise_scores[ALL_MODULE_NAME]),
@@ -590,6 +581,7 @@
     [
         "mlp",
         "conv",
+        "roberta",
     ],
 )
 @pytest.mark.parametrize("use_measurement_for_self_influence", [False, True])
@@ -603,7 +595,7 @@ def test_self_shared_parameters(
     train_size: int,
     seed: int,
 ) -> None:
-    # Makes sure the scores are identical with and without `shared_parameters_exist` flag.
+    # Makes sure the scores are identical with and without `has_shared_parameters` flag.
     model, train_dataset, _, data_collator, task = prepare_test(
         test_name=test_name,
         query_size=query_size,
@@ -616,52 +608,48 @@ def test_self_shared_parameters(
         model=model,
         task=task,
     )
-    factor_args = test_factor_arguments()
-    factor_args.shared_parameters_exist = False
-    score_args = test_score_arguments()
+    factor_args = pytest_factor_arguments()
+    factor_args.has_shared_parameters = False
+    score_args = pytest_score_arguments()
     score_args.use_measurement_for_self_influence = use_measurement_for_self_influence
 
-    factors_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}"
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         factor_args=factor_args,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=8,
         overwrite_output_dir=True,
     )
-    scores_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}_scores"
     analyzer.compute_self_scores(
-        scores_name=scores_name,
+        scores_name=DEFAULT_SCORES_NAME,
         score_args=score_args,
-        factors_name=factors_name,
+        factors_name=DEFAULT_FACTORS_NAME,
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
         overwrite_output_dir=True,
     )
-    scores = analyzer.load_self_scores(scores_name=scores_name)
+    scores = analyzer.load_self_scores(scores_name=DEFAULT_SCORES_NAME)
 
-    factor_args.shared_parameters_exist = True
-    factors_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}_shared"
+    factor_args.has_shared_parameters = True
     analyzer.fit_all_factors(
-        factors_name=factors_name,
+        factors_name=custom_factors_name("shared"),
         factor_args=factor_args,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=8,
         overwrite_output_dir=True,
     )
-    scores_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}_shared_scores"
     analyzer.compute_self_scores(
-        scores_name=scores_name,
+        scores_name=custom_scores_name("shared"),
         score_args=score_args,
-        factors_name=factors_name,
+        factors_name=custom_factors_name("shared"),
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
         overwrite_output_dir=True,
     )
-    shared_scores = analyzer.load_self_scores(scores_name=scores_name)
+    shared_scores = analyzer.load_self_scores(scores_name=custom_scores_name("shared"))
 
     assert check_tensor_dict_equivalence(
         scores,
diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py
index a66879b..885b6c9 100644
--- a/tests/test_analyzer.py
+++ b/tests/test_analyzer.py
@@ -13,10 +13,11 @@
     "test_name",
     [
         "mlp",
+        "mlp_checkpoint",
         "repeated_mlp",
-        "conv",
"conv_bn", "bert", + "roberta", "gpt", ], ) @@ -48,15 +49,15 @@ def test_analyzer( analysis_name=f"pytest_{test_name}", model=model, task=task, - disable_model_save=True, disable_tqdm=True, + disable_model_save=True, cpu=True, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) factor_args = FactorArguments(strategy=strategy) if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True analyzer.fit_all_factors( factors_name=f"pytest_{test_analyzer.__name__}_{test_name}", dataset=train_dataset, @@ -85,11 +86,12 @@ def test_analyzer( score_args=score_args, overwrite_output_dir=True, ) + score_args.use_measurement_for_self_influence = True analyzer.compute_self_scores( scores_name="self", factors_name=f"pytest_{test_analyzer.__name__}_{test_name}", train_dataset=train_dataset, - per_device_train_batch_size=8, + per_device_train_batch_size=6, dataloader_kwargs=kwargs, score_args=score_args, overwrite_output_dir=True, @@ -101,23 +103,22 @@ def test_default_factor_arguments() -> None: assert factor_args.strategy == "ekfac" assert factor_args.use_empirical_fisher is False - assert factor_args.distributed_sync_steps == 1000 assert factor_args.amp_dtype is None - assert factor_args.shared_parameters_exist is False + assert factor_args.has_shared_parameters is False assert factor_args.covariance_max_examples == 100_000 - assert factor_args.covariance_data_partition_size == 1 - assert factor_args.covariance_module_partition_size == 1 + assert factor_args.covariance_data_partitions == 1 + assert factor_args.covariance_module_partitions == 1 assert factor_args.activation_covariance_dtype == torch.float32 assert factor_args.gradient_covariance_dtype == torch.float32 assert factor_args.eigendecomposition_dtype == torch.float64 assert factor_args.lambda_max_examples == 100_000 - assert factor_args.lambda_data_partition_size == 1 - assert factor_args.lambda_module_partition_size == 1 - assert factor_args.lambda_iterative_aggregate is False - assert factor_args.cached_activation_cpu_offload is False + assert factor_args.lambda_data_partitions == 1 + assert factor_args.lambda_module_partitions == 1 + assert factor_args.use_iterative_lambda_aggregation is False + assert factor_args.offload_activations_to_cpu is False assert factor_args.per_sample_gradient_dtype == torch.float32 assert factor_args.lambda_dtype == torch.float32 @@ -125,20 +126,25 @@ def test_default_factor_arguments() -> None: def test_default_score_arguments() -> None: score_args = ScoreArguments() - assert score_args.damping == 1e-08 - assert score_args.cached_activation_cpu_offload is False - assert score_args.distributed_sync_steps == 1000 + assert score_args.damping_factor == 1e-08 assert score_args.amp_dtype is None + assert score_args.offload_activations_to_cpu is False + + assert score_args.data_partitions == 1 + assert score_args.module_partitions == 1 + + assert score_args.compute_per_module_scores is False + assert score_args.compute_per_token_scores is False + + assert score_args.query_gradient_accumulation_steps == 1 + assert score_args.query_gradient_low_rank is None + assert score_args.use_full_svd is False + assert score_args.aggregate_query_gradients is False + assert score_args.aggregate_train_gradients is False - assert score_args.data_partition_size == 1 - assert score_args.module_partition_size == 1 - assert score_args.per_module_score is False assert score_args.use_measurement_for_self_influence is False - assert score_args.query_gradient_rank is None - assert 
-    assert score_args.num_query_gradient_accumulations == 1
     assert score_args.query_gradient_svd_dtype == torch.float32
-
-    assert score_args.score_dtype == torch.float32
     assert score_args.per_sample_gradient_dtype == torch.float32
     assert score_args.precondition_dtype == torch.float32
+    assert score_args.score_dtype == torch.float32
diff --git a/tests/test_testable_tasks.py b/tests/test_testable_tasks.py
index 722a46c..6d84d75 100644
--- a/tests/test_testable_tasks.py
+++ b/tests/test_testable_tasks.py
@@ -8,6 +8,7 @@
     make_conv_model,
 )
 from tests.testable_tasks.language_modeling import make_gpt_dataset, make_tiny_gpt
+from tests.testable_tasks.multiple_choice import make_roberta_dataset, make_tiny_roberta
 from tests.testable_tasks.regression import make_mlp_model, make_regression_dataset
 from tests.testable_tasks.text_classification import make_bert_dataset, make_tiny_bert
 
@@ -69,6 +70,27 @@ def test_bert():
     logits.sum().backward()
 
 
+def test_roberta():
+    model = make_tiny_roberta(seed=0)
+    dataset = make_roberta_dataset(num_data=2, seed=0)
+    batch_size = 2
+    loader = DataLoader(
+        dataset,
+        collate_fn=default_data_collator,
+        batch_size=batch_size,
+        drop_last=False,
+        shuffle=False,
+    )
+
+    batch = next(iter(loader))
+    inputs = (
+        batch["input_ids"],
+        batch["attention_mask"],
+    )
+    logits = model(*inputs).logits
+    logits.sum().backward()
+
+
 def test_gpt():
     model = make_tiny_gpt(seed=0)
     dataset = make_gpt_dataset(num_data=8, seed=0)
diff --git a/tests/testable_tasks/classification.py b/tests/testable_tasks/classification.py
index db89a42..1282f2a 100644
--- a/tests/testable_tasks/classification.py
+++ b/tests/testable_tasks/classification.py
@@ -26,6 +26,18 @@ def make_conv_model(bias: bool = True, seed: int = 0) -> nn.Module:
     )
 
 
+def make_conv_inplace_model(bias: bool = True, seed: int = 0) -> nn.Module:
+    set_seed(seed)
+    return nn.Sequential(
+        nn.Conv2d(3, 4, 3, 1, bias=bias),
+        nn.ReLU(inplace=True),
+        nn.Conv2d(4, 8, 3, 1, bias=bias),
+        nn.ReLU(inplace=True),
+        nn.Flatten(),
+        nn.Linear(1152, 5, bias=bias),
+    )
+
+
 def make_conv_bn_model(bias: bool = True, seed: int = 0) -> nn.Module:
     set_seed(seed)
     return nn.Sequential(
diff --git a/tests/testable_tasks/language_modeling.py b/tests/testable_tasks/language_modeling.py
index d954337..4c7f1e3 100644
--- a/tests/testable_tasks/language_modeling.py
+++ b/tests/testable_tasks/language_modeling.py
@@ -1,20 +1,28 @@
 # pylint: skip-file
 
 from itertools import chain
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import torch
 import torch.nn.functional as F
 from datasets import load_dataset
 from torch import nn
 from torch.utils import data
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Conv1D
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    Conv1D,
+    logging,
+)
 
 from kronfluence.task import Task
 
+logging.set_verbosity_error()
 BATCH_TYPE = Dict[str, torch.Tensor]
 
 
+@torch.no_grad()
 def _replace_conv1d_modules(model: nn.Module) -> None:
     for name, module in model.named_children():
         if len(list(module.children())) > 0:
@@ -69,10 +77,6 @@ def tokenize_function(examples):
 
     def group_texts(examples):
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
-        # Just to make sure attention_mask is correctly implemented.
-        concatenated_examples["attention_mask"][15:30] = [0 for _ in range(15)]
-        concatenated_examples["attention_mask"][90:100] = [0 for _ in range(10)]
-        concatenated_examples["attention_mask"][300:400] = [0 for _ in range(100)]
         total_length = len(concatenated_examples[list(examples.keys())[0]])
         if total_length >= block_size:
             total_length = (total_length // block_size) * block_size
@@ -133,7 +137,7 @@ def compute_measurement(
     ) -> torch.Tensor:
         return self.compute_train_loss(batch, model)
 
-    def tracked_modules(self) -> List[str]:
+    def get_influence_tracked_modules(self) -> List[str]:
         total_modules = []
 
         for i in range(5):
@@ -146,5 +150,5 @@ def tracked_modules(self) -> List[str]:
 
         return total_modules
 
-    def get_attention_mask(self, batch: Any) -> Optional[torch.Tensor]:
+    def get_attention_mask(self, batch: Any) -> torch.Tensor:
         return batch["attention_mask"]
diff --git a/tests/testable_tasks/multiple_choice.py b/tests/testable_tasks/multiple_choice.py
new file mode 100644
index 0000000..971b6fb
--- /dev/null
+++ b/tests/testable_tasks/multiple_choice.py
@@ -0,0 +1,133 @@
+# pylint: skip-file
+
+from itertools import chain
+from typing import Any, Dict
+
+import torch
+import torch.nn.functional as F
+from datasets import load_dataset
+from torch import nn
+from torch.utils import data
+from transformers import AutoConfig, AutoModelForMultipleChoice, AutoTokenizer, logging
+
+from kronfluence.task import Task
+
+logging.set_verbosity_error()
+BATCH_TYPE = Dict[str, torch.Tensor]
+
+
+def make_tiny_roberta(seed: int = 0) -> nn.Module:
+    torch.manual_seed(seed)
+    config = AutoConfig.from_pretrained(
+        "hf-internal-testing/tiny-random-RobertaModel",
+        trust_remote_code=True,
+    )
+    model = AutoModelForMultipleChoice.from_pretrained(
+        "hf-internal-testing/tiny-random-RobertaModel",
+        from_tf=False,
+        config=config,
+        ignore_mismatched_sizes=False,
+        trust_remote_code=True,
+    )
+    return model
+
+
+def make_roberta_dataset(num_data: int, seed: int = 0) -> data.Dataset:
+    torch.manual_seed(seed)
+    raw_datasets = load_dataset("swag", "regular")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "hf-internal-testing/tiny-random-RobertaModel", use_fast=True, trust_remote_code=True
+    )
+
+    column_names = raw_datasets["train"].column_names
+    ending_names = [f"ending{i}" for i in range(4)]
+    context_name = "sent1"
+    question_header_name = "sent2"
+    label_column_name = "label" if "label" in column_names else "labels"
+    padding = "max_length"
+
+    def preprocess_function(examples: Any):
+        first_sentences = [[context] * 4 for context in examples[context_name]]
+        question_headers = examples[question_header_name]
+        second_sentences = [
+            [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
+        ]
+        labels = examples[label_column_name]
+
+        first_sentences = list(chain(*first_sentences))
+        second_sentences = list(chain(*second_sentences))
+
+        tokenized_examples = tokenizer(
+            first_sentences,
+            second_sentences,
+            max_length=128,
+            padding=padding,
+            truncation=True,
+        )
+        tokenized_inputs = {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
+        tokenized_inputs["labels"] = labels
+        return tokenized_inputs
+
+    processed_datasets = raw_datasets.map(
+        preprocess_function,
+        batched=True,
+        remove_columns=raw_datasets["train"].column_names,
+    )
+
+    dataset = processed_datasets["train"].select(range(num_data))
+
+    return dataset
+
+
+class MultipleChoiceTask(Task):
+    enable_post_process_per_sample_gradient = True
+
+    def compute_train_loss(
+        self,
+        batch: BATCH_TYPE,
+        model: nn.Module,
+        sample: bool = False,
+    ) -> torch.Tensor:
+        logits = model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+        ).logits
+
+        if not sample:
+            return F.cross_entropy(logits, batch["labels"], reduction="sum")
+        with torch.no_grad():
+            probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
+            sampled_labels = torch.multinomial(
+                probs,
+                num_samples=1,
+            ).flatten()
+        return F.cross_entropy(logits, sampled_labels, reduction="sum")
+
+    def compute_measurement(
+        self,
+        batch: BATCH_TYPE,
+        model: nn.Module,
+    ) -> torch.Tensor:
+        logits = model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+        ).logits
+
+        labels = batch["labels"]
+        bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
+        logits_correct = logits[bindex, labels]
+
+        cloned_logits = logits.clone()
+        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)
+
+        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
+        return -margins.sum()
+
+    def get_attention_mask(self, batch: Any) -> torch.Tensor:
+        return batch["attention_mask"]
+
+    def post_process_per_sample_gradient(self, module_name: str, gradient: torch.Tensor) -> torch.Tensor:
+        del module_name
+        total_batch_size = gradient.size(0)
+        true_batch_size = int(total_batch_size / 4)
+        return gradient.reshape(true_batch_size, 4, *gradient.size()[1:]).sum(dim=1)
diff --git a/tests/testable_tasks/text_classification.py b/tests/testable_tasks/text_classification.py
index 0a2ac5f..6e9c005 100644
--- a/tests/testable_tasks/text_classification.py
+++ b/tests/testable_tasks/text_classification.py
@@ -7,10 +7,16 @@
 from datasets import load_dataset
 from torch import nn
 from torch.utils import data
-from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    logging,
+)
 
 from kronfluence.task import Task
 
+logging.set_verbosity_error()
 BATCH_TYPE = Dict[str, torch.Tensor]
 
 
@@ -113,7 +119,7 @@ def compute_measurement(
         margins = logits_correct - cloned_logits.logsumexp(dim=-1)
         return -margins.sum()
 
-    def get_attention_mask(self, batch: Any) -> Optional[torch.Tensor]:
+    def get_attention_mask(self, batch: Any) -> torch.Tensor:
         return batch["attention_mask"]
diff --git a/tests/utils.py b/tests/utils.py
index 226f679..03ca290 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,6 +15,7 @@
     WrongClassificationTask,
     make_classification_dataset,
     make_conv_bn_model,
+    make_conv_inplace_model,
     make_conv_model,
 )
 from tests.testable_tasks.language_modeling import (
@@ -22,6 +23,11 @@
     make_gpt_dataset,
     make_tiny_gpt,
 )
+from tests.testable_tasks.multiple_choice import (
+    MultipleChoiceTask,
+    make_roberta_dataset,
+    make_tiny_roberta,
+)
 from tests.testable_tasks.regression import (
     GradientCheckpointRegressionTask,
     RegressionTask,
@@ -39,15 +45,25 @@
 RTOL = 1.3e-6
 ATOL = 1e-5
 
+DEFAULT_FACTORS_NAME = "pytest"
+DEFAULT_SCORES_NAME = "pytest"
+
+
+def custom_factors_name(name: str) -> str:
+    return f"{DEFAULT_FACTORS_NAME}_{name}"
+
+
+def custom_scores_name(name: str) -> str:
+    return f"{DEFAULT_SCORES_NAME}_{name}"
+
 
 def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]:
     model = prepare_model(model=model, task=task)
     analyzer = Analyzer(
-        analysis_name=f"pytest_{__name__}",
+        analysis_name="pytest",
         model=model,
         task=task,
         disable_model_save=True,
-        cpu=True,
         disable_tqdm=True,
     )
     return model, analyzer
@@ -84,6 +100,12 @@ def prepare_test(
         query_dataset = make_classification_dataset(num_data=query_size, seed=seed + 1)
         task = ClassificationTask()
         data_collator = None
+    elif test_name == "conv_inplace":
+        model = make_conv_inplace_model(seed=seed)
+        train_dataset = make_classification_dataset(num_data=train_size, seed=seed)
+        query_dataset = make_classification_dataset(num_data=query_size, seed=seed + 1)
+        task = ClassificationTask()
+        data_collator = None
     elif test_name == "wrong_conv":
         model = make_conv_model(seed=seed)
         train_dataset = make_classification_dataset(num_data=train_size, seed=seed)
@@ -108,12 +130,25 @@ def prepare_test(
         query_dataset = make_bert_dataset(num_data=query_size, seed=seed + 1, do_not_pad=do_not_pad)
         task = WrongTextClassificationTask()
         data_collator = default_data_collator
+    elif test_name == "roberta":
+        model = make_tiny_roberta(seed=seed)
+        train_dataset = make_roberta_dataset(num_data=train_size, seed=seed)
+        query_dataset = make_roberta_dataset(num_data=query_size, seed=seed + 1)
+        task = MultipleChoiceTask()
+        data_collator = default_data_collator
     elif test_name == "gpt":
         model = make_tiny_gpt(seed=seed)
         train_dataset = make_gpt_dataset(num_data=train_size, seed=seed)
         query_dataset = make_gpt_dataset(num_data=query_size, seed=seed + 1)
         task = LanguageModelingTask()
         data_collator = default_data_collator
+    elif test_name == "gpt_checkpoint":
+        model = make_tiny_gpt(seed=seed)
+        model.gradient_checkpointing_enable()
+        train_dataset = make_gpt_dataset(num_data=train_size, seed=seed)
+        query_dataset = make_gpt_dataset(num_data=query_size, seed=seed + 1)
+        task = LanguageModelingTask()
+        data_collator = default_data_collator
     else:
         raise NotImplementedError(f"{test_name} is not a valid test configuration name.")
     model.eval()
@@ -172,6 +207,6 @@ def reshape_parameter_gradient_to_module_matrix(
             if remove_gradient:
                 del gradient_dict[module_name + ".bias"]
     else:
-        error_msg = f"Unsupported module type: {type(module)}. Only nn.Linear or nn.Conv2d are supported."
+        error_msg = f"Unsupported module type: {type(module)}. Only `nn.Linear` or `nn.Conv2d` are supported."
         raise UnsupportableModuleError(error_msg)
     return gradient_matrix