Finalize factor computations
pomonam committed Jul 6, 2024
1 parent d0154f1 commit a80595e
Showing 7 changed files with 25 additions and 80 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/python-test.yml
@@ -28,10 +28,9 @@ jobs:
pytest -vx tests/test_dataset_utils.py
pytest -vx tests/test_testable_tasks.py
pytest -vx tests/factors/test_covariances.py
pytest -vx tests/factors/test_eigens.py
pytest -vx tests/factors/test_eigendecompositions.py
pytest -vx tests/factors/test_lambdas.py
pytest -vx tests/modules/test_modules.py
pytest -vx tests/modules/test_per_sample_gradients.py
pytest -vx tests/modules/test_svd.py
pytest -vx tests/scores/test_pairwise_scores.py
pytest -vx tests/scores/test_self_scores.py
1 change: 1 addition & 0 deletions kronfluence/computer/factor_computer.py
@@ -465,6 +465,7 @@ def perform_eigendecomposition(
output_dir=factors_output_dir, factors=eigen_factors, metadata=factor_args.to_str_dict()
)
self.logger.info(f"Saved eigendecomposition results at `{factors_output_dir}`.")
del eigen_factors
self._reset_memory()
self.state.wait_for_everyone()
self._log_profile_summary(name=f"factors_{factors_name}_eigendecomposition")
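For context, the added `del eigen_factors` follows a common pattern once large results have been written to disk: drop the in-memory references, clear allocator caches, and only then synchronize ranks. Below is a minimal sketch of that pattern, assuming a generic `torch.distributed` setup; the function name `save_and_release` and the save path are placeholders, not kronfluence's API, and `_reset_memory` in the repository may do more than this.

```python
import gc

import torch
import torch.distributed as dist


def save_and_release(factors: dict, output_dir: str) -> None:
    """Save factors, then free memory before synchronizing ranks (illustrative sketch)."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(factors, f"{output_dir}/factors.pt")  # placeholder output path
    del factors  # drop the last reference so the underlying tensors can be reclaimed
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the CUDA driver
    if dist.is_initialized():
        dist.barrier()  # ensure every rank waits until the results are saved
```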
2 changes: 1 addition & 1 deletion kronfluence/factor/eigen.py
@@ -419,7 +419,6 @@ def fit_lambda_matrices_with_loader(
sample=not factor_args.use_empirical_fisher,
)
scaler.scale(loss).backward()
del loss

if factor_args.has_shared_parameters:
finalize_iteration(model=model, tracked_module_names=tracked_module_names)
@@ -432,6 +431,7 @@ def fit_lambda_matrices_with_loader(
state.wait_for_everyone()

num_data_processed.add_(find_batch_size(data=batch))
del batch, loss
total_steps += 1
pbar.update(1)

26 changes: 16 additions & 10 deletions kronfluence/module/tracker/factor.py
@@ -62,7 +62,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None:

if self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None:
# In most cases, `NUM_GRADIENT_COVARIANCE_PROCESSED` and `NUM_ACTIVATION_COVARIANCE_PROCESSED` are
# identical. However, they may differ when using gradient checkpointing or torch.compile().
# identical. However, they may differ when using gradient checkpointing or `torch.compile()`.
self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros(
size=(1,),
dtype=torch.int64,
@@ -85,19 +85,22 @@ def register_hooks(self) -> None:
def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
del module
with torch.no_grad():
# Computes and updates activation covariance during forward pass.
input_activation = (
inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype,
copy=self.module.attention_mask is not None)
inputs[0]
.detach()
.to(
dtype=self.module.factor_args.activation_covariance_dtype,
copy=self.module.attention_mask is not None,
)
)
# Computes and updates activation covariance during forward pass.
self._update_activation_covariance_matrix(input_activation=input_activation)
self.cached_hooks.append(outputs.register_hook(backward_hook))

@torch.no_grad()
def backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
# Computes and updates pseudo-gradient covariance during backward pass.
original_dtype = output_gradient.dtype
target_dtype = self.module.factor_args.gradient_covariance_dtype
output_gradient = output_gradient.detach().to(dtype=target_dtype)
@@ -106,6 +109,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale
# Computes and updates pseudo-gradient covariance during backward pass.
self._update_gradient_covariance_matrix(output_gradient=output_gradient)

self.registered_hooks.append(self.module.register_forward_hook(forward_hook))
@@ -245,7 +249,6 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
def backward_hook(output_gradient: torch.Tensor) -> None:
if self.cached_activations is None:
self._raise_cache_not_found_exception()

handle = self.cached_hooks.pop()
handle.remove()
original_dtype = output_gradient.dtype
@@ -256,11 +259,13 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale
self.cached_activations = self.cached_activations.to(device=output_gradient.device)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
input_activation=self.cached_activations,
output_gradient=output_gradient,
).to(dtype=self.module.factor_args.lambda_dtype)
self.clear_all_cache()
# Computes and updates lambda matrix during backward pass.
self._update_lambda_matrix(per_sample_gradient=per_sample_gradient)

@torch.no_grad()
@@ -276,15 +281,17 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None:
else:
output_gradient = output_gradient * self.module.gradient_scale
cached_activation = self.cached_activations.pop()
cached_activation = cached_activation.to(device=output_gradient.device)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
input_activation=cached_activation,
output_gradient=output_gradient,
)
if self.cached_per_sample_gradient is None:
self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False)
# Aggregates per-sample gradients during backward pass.
self.cached_per_sample_gradient.add_(per_sample_gradient)

self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
self.registered_hooks.append(self.module.register_forward_hook(forward_hook))

@torch.no_grad()
def finalize_iteration(self) -> None:
@@ -319,5 +326,4 @@ def release_memory(self) -> None:
"""Clears Lambda matrices from memory."""
self.clear_all_cache()
for lambda_factor_name in LAMBDA_FACTOR_NAMES:
del self.module.storage[lambda_factor_name]
self.module.storage[lambda_factor_name] = None
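For orientation, the tracker changed above accumulates an activation covariance in a forward hook and a pseudo-gradient covariance in a tensor backward hook registered on the module's output. The following is a minimal, self-contained sketch of that general pattern for a single `nn.Linear` layer; it is illustrative only (not kronfluence's tracker classes), and the names `activation_cov` and `gradient_cov` are placeholders.

```python
import torch
import torch.nn as nn

linear = nn.Linear(4, 3)
activation_cov = torch.zeros(4, 4)  # accumulates A^T A over batches
gradient_cov = torch.zeros(3, 3)    # accumulates G^T G over batches


def forward_hook(module: nn.Module, inputs, outputs) -> None:
    with torch.no_grad():
        a = inputs[0].detach()
        activation_cov.addmm_(a.t(), a)  # update activation covariance with this batch
    # Hook the output tensor so the backward pass reports its gradient.
    outputs.register_hook(backward_hook)


@torch.no_grad()
def backward_hook(output_gradient: torch.Tensor) -> None:
    g = output_gradient.detach()
    gradient_cov.addmm_(g.t(), g)  # update pseudo-gradient covariance


linear.register_forward_hook(forward_hook)
loss = linear(torch.randn(8, 4)).square().sum()
loss.backward()
```

After `loss.backward()`, both accumulators hold one batch's worth of statistics; in practice they would be updated over many batches and normalized by the number of processed examples.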
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,4 +34,5 @@ disable = """
implicit-str-concat,
inconsistent-return-statements,
too-many-lines,
too-many-public-methods,
"""
71 changes: 5 additions & 66 deletions tests/factors/test_covariances.py
@@ -33,12 +33,13 @@
"repeated_mlp",
"conv",
"bert",
"roberta",
"gpt",
"gpt_checkpoint",
],
)
@pytest.mark.parametrize("activation_covariance_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("gradient_covariance_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("activation_covariance_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("gradient_covariance_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("train_size", [16])
@pytest.mark.parametrize("seed", [0])
def test_fit_covariance_matrices(
@@ -59,7 +60,6 @@ def test_fit_covariance_matrices(
model=model,
task=task,
)

factor_args = default_factor_arguments()
factor_args.activation_covariance_dtype = activation_covariance_dtype
factor_args.gradient_covariance_dtype = gradient_covariance_dtype
@@ -387,70 +387,9 @@ def test_covariance_matrices_max_examples(
assert num_examples == max_examples


@pytest.mark.parametrize(
"test_name",
[
"mlp",
"conv_bn",
],
)
@pytest.mark.parametrize("train_size", [101])
@pytest.mark.parametrize("seed", [8])
def test_covariance_matrices_amp(
test_name: str,
train_size: int,
seed: int,
) -> None:
# Covariance matrices should be similar even when AMP is enabled.
model, train_dataset, _, data_collator, task = prepare_test(
test_name=test_name,
train_size=train_size,
seed=seed,
)
kwargs = DataLoaderKwargs(collate_fn=data_collator)
model, analyzer = prepare_model_and_analyzer(
model=model,
task=task,
)

factor_args = pytest_factor_arguments()
analyzer.fit_covariance_matrices(
factors_name=DEFAULT_FACTORS_NAME,
dataset=train_dataset,
per_device_batch_size=8,
overwrite_output_dir=True,
factor_args=factor_args,
dataloader_kwargs=kwargs,
)
covariance_factors = analyzer.load_covariance_matrices(
factors_name=DEFAULT_FACTORS_NAME,
)

factor_args.amp_dtype = torch.float16
analyzer.fit_covariance_matrices(
factors_name=custom_factors_name("amp"),
dataset=train_dataset,
per_device_batch_size=8,
overwrite_output_dir=True,
factor_args=factor_args,
dataloader_kwargs=kwargs,
)
amp_covariance_factors = analyzer.load_covariance_matrices(
factors_name=custom_factors_name("amp"),
)

for name in COVARIANCE_FACTOR_NAMES:
assert check_tensor_dict_equivalence(
covariance_factors[name],
amp_covariance_factors[name],
atol=1e-01,
rtol=1e-02,
)


@pytest.mark.parametrize("test_name", ["mlp", "gpt"])
@pytest.mark.parametrize("train_size", [100])
@pytest.mark.parametrize("seed", [7])
@pytest.mark.parametrize("seed", [6])
def test_covariance_matrices_gradient_checkpoint(
test_name: str,
train_size: int,
@@ -514,7 +453,7 @@ def test_covariance_matrices_gradient_checkpoint(


@pytest.mark.parametrize("train_size", [100])
@pytest.mark.parametrize("seed", [8, 9])
@pytest.mark.parametrize("seed", [7, 8])
def test_covariance_matrices_inplace(
train_size: int,
seed: int,
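The remaining covariance tests follow one structure: fit factors twice under two configurations (for example, with and without gradient checkpointing) and assert that the resulting tensor dictionaries are numerically close. A rough sketch of such an equivalence check is shown below; `tensor_dicts_close` is a hypothetical helper written here for illustration, not the repository's `check_tensor_dict_equivalence`, and the tolerances mirror the values used in the removed AMP test.

```python
from typing import Dict

import torch


def tensor_dicts_close(
    d1: Dict[str, torch.Tensor],
    d2: Dict[str, torch.Tensor],
    atol: float = 1e-1,
    rtol: float = 1e-2,
) -> bool:
    """Return True if both dicts hold the same keys and numerically close tensors."""
    if d1.keys() != d2.keys():
        return False
    return all(
        torch.allclose(d1[key].float(), d2[key].float(), atol=atol, rtol=rtol)
        for key in d1
    )


# Example: compare factors computed under two configurations (toy tensors).
reference = {"activation_covariance": torch.eye(3)}
candidate = {"activation_covariance": torch.eye(3) + 1e-3}
assert tensor_dicts_close(reference, candidate)
```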
1 change: 0 additions & 1 deletion tests/utils.py
@@ -64,7 +64,6 @@ def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module,
model=model,
task=task,
disable_model_save=True,
# cpu=True,
disable_tqdm=True,
)
return model, analyzer
