
Commit d7e35bc

Merge pull request #29 from pomonam/amp_logic
Disable GradScaler
2 parents: 52a67f9 + 3f397cf

File tree: 15 files changed (+102, -77 lines)

DOCUMENTATION.md

Lines changed: 2 additions & 0 deletions
@@ -205,6 +205,7 @@ factor_args = FactorArguments(
     strategy="ekfac",  # Choose from "identity", "diagonal", "kfac", or "ekfac".
     use_empirical_fisher=False,
     amp_dtype=None,
+    amp_scale=2.0**16,
     has_shared_parameters=False,

     # Settings for covariance matrix fitting.
@@ -236,6 +237,7 @@ You can change:
 - `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch)
 instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`.
 - `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
+- `amp_scale`: Sets the scale factor for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html).
 - `has_shared_parameters`: Specifies whether the shared parameters exist in the forward pass.

 ### Fitting Covariance Matrices
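For orientation, here is a minimal sketch (not part of the diff) of how the new `amp_scale` option slots into the `FactorArguments` configuration described above; `torch.float16` is chosen only because that is the case where the scale actually matters:

```python
import torch

from kronfluence.arguments import FactorArguments

# Sketch: run factor fitting under float16 AMP with an explicit loss scale.
# amp_scale is only relevant when amp_dtype is torch.float16; with bfloat16
# or full precision the GradScaler stays disabled.
factor_args = FactorArguments(
    strategy="ekfac",
    use_empirical_fisher=False,
    amp_dtype=torch.float16,
    amp_scale=2.0**16,  # default value, matching GradScaler's init_scale
    has_shared_parameters=False,
)
```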

examples/cifar/inspect_factors.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import logging
+
+import matplotlib.pyplot as plt
+
+from kronfluence.analyzer import Analyzer
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    name = "ekfac"
+    factor = (
+        Analyzer.load_file(f"influence_results/cifar10/factors_{name}/activation_covariance.safetensors")
+    )
+
+    plt.matshow(factor["6.0"])
+    plt.show()
+
+    factor = (
+        Analyzer.load_file(f"influence_results/cifar10/factors_{name}/gradient_covariance.safetensors")
+    )
+
+    plt.matshow(factor["6.0"])
+    plt.show()
+
+
+if __name__ == "__main__":
+    main()
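A possible companion snippet (not part of the commit) for checking which module names are stored in a saved factor file before hard-coding a key such as "6.0"; it assumes the CIFAR-10 factors referenced above have already been fitted:

```python
from kronfluence.analyzer import Analyzer

# List the module names stored in the saved activation covariance file so the
# key passed to plt.matshow (e.g., "6.0") can be chosen by inspection.
factors = Analyzer.load_file("influence_results/cifar10/factors_ekfac/activation_covariance.safetensors")
print(sorted(factors.keys()))
```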

examples/openwebtext/compute_scores.py

Lines changed: 4 additions & 2 deletions
@@ -20,8 +20,10 @@
 from kronfluence.utils.common.factor_arguments import (
     extreme_reduce_memory_factor_arguments,
 )
-from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments, \
-    extreme_reduce_memory_score_arguments
+from kronfluence.utils.common.score_arguments import (
+    all_low_precision_score_arguments,
+    extreme_reduce_memory_score_arguments,
+)
 from kronfluence.utils.dataset import DataLoaderKwargs

 BATCH_TYPE = Dict[str, torch.Tensor]

kronfluence/arguments.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,10 @@ class FactorArguments(Arguments):
         default=None,
         metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."},
     )
+    amp_scale: float = field(
+        default=2.0**16,
+        metadata={"help": "Scale factor for AMP (only applicable when `amp_dtype=torch.float16`)."},
+    )
     has_shared_parameters: bool = field(
         default=False,
         metadata={"help": "Indicates whether shared parameters are present in the model's forward pass."},

kronfluence/factor/covariance.py

Lines changed: 4 additions & 3 deletions
@@ -196,8 +196,9 @@ def fit_covariance_matrices_with_loader(
     total_steps = 0
     num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
     enable_amp = factor_args.amp_dtype is not None
-    scaler = GradScaler(enabled=enable_amp)
-    if enable_amp:
+    enable_grad_scaler = enable_amp and factor_args.amp_dtype == torch.float16
+    scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
+    if enable_grad_scaler:
         gradient_scale = 1.0 / scaler.get_scale()
         set_gradient_scale(model=model, gradient_scale=gradient_scale)

@@ -257,7 +258,7 @@ def fit_covariance_matrices_with_loader(

     model.zero_grad(set_to_none=True)
     set_attention_mask(model=model, attention_mask=None)
-    if enable_amp:
+    if enable_grad_scaler:
         set_gradient_scale(model=model, gradient_scale=1.0)
     set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
     state.wait_for_everyone()
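The hunk above (and the matching one in eigen.py below) changes when the scaler is active: loss scaling is only needed to keep float16 gradients from underflowing, while bfloat16 retains float32's exponent range. A self-contained sketch of the new condition; the `make_scaler` helper is hypothetical and only stands in for the logic shown in the diff:

```python
from typing import Optional

import torch
from torch.cuda.amp import GradScaler


def make_scaler(amp_dtype: Optional[torch.dtype], amp_scale: float = 2.0**16) -> GradScaler:
    """Sketch of the updated logic: the scaler is active only for float16 AMP."""
    enable_amp = amp_dtype is not None
    # bfloat16 does not need loss scaling, so the scaler stays disabled for it.
    enable_grad_scaler = enable_amp and amp_dtype == torch.float16
    return GradScaler(init_scale=amp_scale, enabled=enable_grad_scaler)


scaler = make_scaler(torch.float16)
# On a CUDA machine this prints 65536.0; a disabled scaler reports 1.0, which is
# why gradient_scale = 1.0 / scaler.get_scale() only deviates from 1 for float16.
print(scaler.get_scale())
```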

kronfluence/factor/eigen.py

Lines changed: 4 additions & 3 deletions
@@ -394,8 +394,9 @@ def fit_lambda_matrices_with_loader(
     total_steps = 0
     num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
     enable_amp = factor_args.amp_dtype is not None
-    scaler = GradScaler(enabled=enable_amp)
-    if enable_amp:
+    enable_grad_scaler = enable_amp and factor_args.amp_dtype == torch.float16
+    scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
+    if enable_grad_scaler:
         gradient_scale = 1.0 / scaler.get_scale()
         set_gradient_scale(model=model, gradient_scale=gradient_scale)

@@ -453,7 +454,7 @@ def fit_lambda_matrices_with_loader(
             saved_factors[factor_name] = factor

     model.zero_grad(set_to_none=True)
-    if enable_amp:
+    if enable_grad_scaler:
         set_gradient_scale(model=model, gradient_scale=1.0)
     set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
     state.wait_for_everyone()

kronfluence/module/tracker/base.py

Lines changed: 0 additions & 22 deletions
@@ -47,28 +47,6 @@ def _raise_cache_not_found_exception(self) -> None:
             f"For case 2, set 'has_shared_parameters=True' to enable parameter sharing."
         )

-    def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
-        """Preprocesses the output gradient.
-
-        Args:
-            output_gradient (torch.Tensor):
-                The original output gradient.
-            target_dtype (torch.dtype):
-                The desired data type for the gradient tensor.
-
-        Returns:
-            torch.Tensor:
-                The preprocessed gradient.
-        """
-        original_dtype = output_gradient.dtype
-        output_gradient = output_gradient.to(dtype=target_dtype)
-        if self.module.gradient_scale != 1.0:
-            if original_dtype != target_dtype:
-                output_gradient.mul_(self.module.gradient_scale)
-            else:
-                output_gradient = output_gradient * self.module.gradient_scale
-        return output_gradient
-
     def register_hooks(self) -> None:
         """Registers hooks for the module."""
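The removed `_preprocess_gradient` helper used to cast and rescale every raw output gradient. In the tracker diffs that follow, the gradient is now only cast, and `gradient_scale` is applied once to the smaller derived quantity (per-sample gradient, covariance, preconditioned gradient, or score). A self-contained schematic of that pattern, using a linear-layer outer product as a stand-in for `compute_per_sample_gradient` (illustrative values only, not repository code):

```python
import torch

# gradient_scale is 1 / amp_scale, as set by set_gradient_scale in the factor code.
gradient_scale = 1.0 / 2.0**16
output_gradient = torch.randn(8, 16)   # stands in for a hook's output gradient
input_activation = torch.randn(8, 32)  # cached activation for the same batch

# New pattern: cast only, then fold the AMP scale into the derived quantity.
output_gradient = output_gradient.detach().to(dtype=torch.float32)
per_sample_gradient = torch.einsum("bi,bj->bij", output_gradient, input_activation)
if gradient_scale != 1.0:
    per_sample_gradient.mul_(gradient_scale)
```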

kronfluence/module/tracker/factor.py

Lines changed: 11 additions & 10 deletions
@@ -87,7 +87,10 @@ def _update_gradient_covariance_matrix(
             )
             self._gradient_covariance_initialized = True
         self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count)
-        self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient)
+        alpha = 1
+        if self.module.gradient_scale != 1.0:
+            alpha = self.module.gradient_scale**2.0
+        self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient, alpha=alpha)

     def register_hooks(self) -> None:
         """Sets up hooks to compute activation and gradient covariance matrices."""
@@ -112,9 +115,7 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.
            def backward_hook(output_gradient: torch.Tensor) -> None:
                handle = self.cached_hooks.pop()
                handle.remove()
-                output_gradient = self._preprocess_gradient(
-                    output_gradient.detach(), target_dtype=self.module.factor_args.gradient_covariance_dtype
-                )
+                output_gradient = output_gradient.detach().to(dtype=self.module.factor_args.gradient_covariance_dtype)
                # Computes and updates pseudo-gradient covariance during backward pass.
                output_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient)
                self._update_gradient_covariance_matrix(output_gradient=output_gradient, count=count)
@@ -259,25 +260,23 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                self._raise_cache_not_found_exception()
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.factor_args.per_sample_gradient_dtype)
            per_sample_gradient = self.module.compute_per_sample_gradient(
                input_activation=self.cached_activations.to(device=output_gradient.device),
                output_gradient=output_gradient,
            ).to(dtype=self.module.factor_args.lambda_dtype)
            self.clear_all_cache()
            del output_gradient
+            if self.module.gradient_scale != 1.0:
+                per_sample_gradient.mul_(self.module.gradient_scale)
            # Computes and updates lambda matrix during backward pass.
            self._update_lambda_matrix(per_sample_gradient=per_sample_gradient)

        @torch.no_grad()
        def shared_backward_hook(output_gradient: torch.Tensor) -> None:
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.factor_args.per_sample_gradient_dtype)
            cached_activation = self.cached_activations.pop()
            per_sample_gradient = self.module.compute_per_sample_gradient(
                input_activation=cached_activation.to(device=output_gradient.device),
@@ -297,6 +296,8 @@ def finalize_iteration(self) -> None:
            self.cached_per_sample_gradient = self.cached_per_sample_gradient.to(
                dtype=self.module.factor_args.lambda_dtype
            )
+            if self.module.gradient_scale != 1.0:
+                self.cached_per_sample_gradient.mul_(self.module.gradient_scale)
            self._update_lambda_matrix(per_sample_gradient=self.cached_per_sample_gradient)
        self.clear_all_cache()

kronfluence/module/tracker/gradient.py

Lines changed: 3 additions & 3 deletions
@@ -38,9 +38,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                self._raise_cache_not_found_exception()
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
            if isinstance(self.cached_activations, list):
                cached_activation = self.cached_activations.pop()
            else:
@@ -56,6 +54,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                input_activation=cached_activation.to(device=output_gradient.device),
                output_gradient=output_gradient,
            ).sum(dim=0, keepdim=True)
+            if self.module.gradient_scale != 1.0:
+                summed_gradient.mul_(self.module.gradient_scale)
            if self.module.storage[AGGREGATED_GRADIENT_NAME] is None:
                self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False)
            self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient)

kronfluence/module/tracker/pairwise_score.py

Lines changed: 5 additions & 3 deletions
@@ -76,9 +76,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                self._raise_cache_not_found_exception()
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient.detach(), target_dtype=self.module.score_args.score_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.score_dtype)
            if isinstance(self.cached_activations, list):
                cached_activation = self.cached_activations.pop()
            else:
@@ -90,6 +88,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                    input_activation=cached_activation.to(device=output_gradient.device),
                    output_gradient=output_gradient,
                )
+                if self.module.gradient_scale != 1.0:
+                    self.module.storage[PAIRWISE_SCORE_MATRIX_NAME].mul_(self.module.gradient_scale)
                del cached_activation, output_gradient
                self.clear_all_cache()
            else:
@@ -98,6 +98,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                    output_gradient=output_gradient,
                )
                del cached_activation, output_gradient
+                if self.module.gradient_scale != 1.0:
+                    per_sample_gradient.mul_(self.module.gradient_scale)
                self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient)

        self.registered_hooks.append(self.module.register_forward_hook(forward_hook))

kronfluence/module/tracker/precondition.py

Lines changed: 6 additions & 6 deletions
@@ -105,9 +105,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                self._raise_cache_not_found_exception()
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
            per_sample_gradient = self.module.compute_per_sample_gradient(
                input_activation=self.cached_activations.to(device=output_gradient.device),
                output_gradient=output_gradient,
@@ -119,16 +117,16 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                gradient=per_sample_gradient,
                storage=self.module.storage,
            )
+            if self.module.gradient_scale != 1.0:
+                preconditioned_gradient.mul_(self.module.gradient_scale)
            del per_sample_gradient
            self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient)

        @torch.no_grad()
        def shared_backward_hook(output_gradient: torch.Tensor) -> None:
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
            cached_activation = self.cached_activations.pop()
            per_sample_gradient = self.module.compute_per_sample_gradient(
                input_activation=cached_activation.to(device=output_gradient.device),
@@ -153,6 +151,8 @@ def finalize_iteration(self) -> None:
                storage=self.module.storage,
            )
            self.cached_per_sample_gradient = None
+            if self.module.gradient_scale != 1.0:
+                preconditioned_gradient.mul_(self.module.gradient_scale)
            self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient)
        self.clear_all_cache()


kronfluence/module/tracker/self_score.py

Lines changed: 11 additions & 9 deletions
@@ -91,24 +91,22 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                self._raise_cache_not_found_exception()
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
            per_sample_gradient = self.module.compute_per_sample_gradient(
                input_activation=self.cached_activations.to(device=output_gradient.device),
                output_gradient=output_gradient,
            ).to(dtype=self.module.score_args.precondition_dtype)
            self.clear_all_cache()
            del output_gradient
+            if self.module.gradient_scale != 1.0:
+                per_sample_gradient.mul_(self.module.gradient_scale)
            self._compute_self_score(per_sample_gradient=per_sample_gradient)

        @torch.no_grad()
        def shared_backward_hook(output_gradient: torch.Tensor) -> None:
            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
            cached_activation = self.cached_activations.pop()
            per_sample_gradient = self.module.compute_per_sample_gradient(
                input_activation=cached_activation.to(device=output_gradient.device),
@@ -127,6 +125,8 @@ def finalize_iteration(self) -> None:
            self.cached_per_sample_gradient = self.cached_per_sample_gradient.to(
                dtype=self.module.score_args.precondition_dtype
            )
+            if self.module.gradient_scale != 1.0:
+                self.cached_per_sample_gradient.mul_(self.module.gradient_scale)
            self._compute_self_score(per_sample_gradient=self.cached_per_sample_gradient)
        self.clear_all_cache()

@@ -202,9 +202,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:

            handle = self.cached_hooks.pop()
            handle.remove()
-            output_gradient = self._preprocess_gradient(
-                output_gradient.detach(), target_dtype=self.module.score_args.score_dtype
-            )
+            output_gradient = output_gradient.detach().to(dtype=self.module.score_args.score_dtype)
            if isinstance(self.cached_activations, list):
                cached_activation = self.cached_activations.pop()
            else:
@@ -217,6 +215,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                )
                self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None
                self.clear_all_cache()
+                if self.module.gradient_scale != 1.0:
+                    scores.mul_(self.module.gradient_scale)
                if self.module.storage[SELF_SCORE_VECTOR_NAME] is None:
                    self.module.storage[SELF_SCORE_VECTOR_NAME] = scores
                else:
@@ -227,6 +227,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                    output_gradient=output_gradient,
                )
                del cached_activation, output_gradient
+                if self.module.gradient_scale != 1.0:
+                    per_sample_gradient.mul_(self.module.gradient_scale)
                self._compute_self_measurement_score_with_gradient(per_sample_gradient=per_sample_gradient)

        self.registered_hooks.append(self.module.register_forward_hook(forward_hook))
