|
1 | 1 | from typing import Optional
|
2 | 2 |
|
3 | 3 | from accelerate.utils import extract_model_from_parallel
|
4 |
| -from factor.config import FactorConfig |
5 | 4 | from safetensors.torch import save_file
|
6 | 5 | from torch import nn
|
7 | 6 | from torch.utils import data
|
|
11 | 10 | from kronfluence.computer.eigen_computer import EigenComputer
|
12 | 11 | from kronfluence.computer.pairwise_score_computer import PairwiseScoreComputer
|
13 | 12 | from kronfluence.computer.self_score_computer import SelfScoreComputer
|
14 |
| -from kronfluence.module.constants import FACTOR_TYPE |
15 | 13 | from kronfluence.module.utils import wrap_tracked_modules
|
16 | 14 | from kronfluence.task import Task
|
17 | 15 | from kronfluence.utils.dataset import DataLoaderKwargs
|
@@ -121,7 +119,7 @@ def fit_all_factors(
|
121 | 119 | dataloader_kwargs: Optional[DataLoaderKwargs] = None,
|
122 | 120 | factor_args: Optional[FactorArguments] = None,
|
123 | 121 | overwrite_output_dir: bool = False,
|
124 |
| - ) -> Optional[FACTOR_TYPE]: |
| 122 | + ) -> None: |
125 | 123 | """Computes all necessary factors for the given factor strategy. As an example, EK-FAC
|
126 | 124 | requires (1) computing covariance matrices, (2) performing Eigendecomposition, and
|
127 | 125 | (3) computing Lambda (corrected-eigenvalues) matrices.
|
@@ -163,11 +161,3 @@ def fit_all_factors(
|
163 | 161 | factor_args=factor_args,
|
164 | 162 | overwrite_output_dir=overwrite_output_dir,
|
165 | 163 | )
|
166 |
| - |
167 |
| - if factor_args is None: |
168 |
| - factor_args = FactorArguments() |
169 |
| - strategy = factor_args.strategy |
170 |
| - factor_config = FactorConfig.CONFIGS[strategy] |
171 |
| - return self._load_all_required_factors( |
172 |
| - factors_name=factors_name, strategy=strategy, factor_config=factor_config |
173 |
| - ) |
0 commit comments