Skip to content

Commit

Permalink
Clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 29, 2024
1 parent 2569fbb commit db1c89f
Show file tree
Hide file tree
Showing 13 changed files with 291 additions and 233 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ accelerate>=0.31.0
einops>=0.8.0
einconv>=0.1.0
opt_einsum>=3.3.0
scikit-learn>=1.4.0
safetensors>=0.4.2
tqdm>=4.66.4
datasets>=2.20.0
Expand Down
7 changes: 7 additions & 0 deletions kronfluence/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@ class ScoreArguments(Arguments):
default=None,
metadata={"help": "Rank for the query gradient. Does not apply low-rank approximation if None."},
)
use_full_svd: bool = field(
default=True,
metadata={
"help": "Whether to perform to use `torch.linalg.svd` instead of `torch.svd_lowrank` for "
"query batching. `torch.svd_lowrank` can result in a more inaccurate low-rank approximations."
},
)
use_measurement_for_self_influence: bool = field(
default=False,
metadata={"help": "Whether to use the measurement (instead of the loss) for computing self-influence scores."},
Expand Down
2 changes: 0 additions & 2 deletions kronfluence/computer/factor_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
save_eigendecomposition,
save_lambda_matrices,
)
from kronfluence.module.tracked_module import ModuleMode
from kronfluence.module.utils import set_mode
from kronfluence.utils.constants import FACTOR_TYPE
from kronfluence.utils.dataset import DataLoaderKwargs, find_executable_batch_size
from kronfluence.utils.exceptions import FactorsNotFoundError
Expand Down
19 changes: 7 additions & 12 deletions kronfluence/computer/score_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from kronfluence.utils.exceptions import FactorsNotFoundError
from kronfluence.utils.logger import get_time
from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME, SCORE_ARGUMENTS_NAME
from kronfluence.utils.state import release_memory


class ScoreComputer(Computer):
Expand Down Expand Up @@ -157,7 +156,7 @@ def _find_executable_pairwise_scores_batch_size(
if self.state.use_distributed:
error_msg = (
"Automatic batch size search is currently not supported for multi-GPU training. "
"Please manually configure the batch size by passing in `per_device_train_batch_size`."
"Please manually configure the batch size by passing in `per_device_batch_size`."
)
self.logger.error(error_msg)
raise NotImplementedError(error_msg)
Expand All @@ -174,9 +173,7 @@ def _find_executable_pairwise_scores_batch_size(
def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Releases all memory that could be caused by the previous OOM.
self.model.zero_grad(set_to_none=True)
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
self._reset_memory()
total_batch_size = batch_size * self.state.num_processes
query_loader = self._get_dataloader(
dataset=query_dataset,
Expand Down Expand Up @@ -377,7 +374,7 @@ def compute_pairwise_scores(
tracked_modules_name=module_partition_names[module_partition],
)

release_memory()
self._reset_memory()
start_time = get_time(state=self.state)
with self.profiler.profile("Compute Pairwise Score"):
query_loader = self._get_dataloader(
Expand Down Expand Up @@ -431,7 +428,7 @@ def compute_pairwise_scores(
self.aggregate_pairwise_scores(scores_name=scores_name)
self.logger.info(f"Saved aggregated pairwise scores at `{scores_output_dir}`.")
self.state.wait_for_everyone()
self._log_profile_summary()
self._log_profile_summary(name=f"scores_{scores_name}_pairwise")

@torch.no_grad()
def aggregate_pairwise_scores(self, scores_name: str) -> None:
Expand Down Expand Up @@ -491,9 +488,7 @@ def _find_executable_self_scores_batch_size(
def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Releases all memory that could be caused by the previous OOM.
self.model.zero_grad(set_to_none=True)
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
self._reset_memory()
total_batch_size = batch_size * self.state.num_processes
train_loader = self._get_dataloader(
dataset=train_dataset,
Expand Down Expand Up @@ -672,7 +667,7 @@ def compute_self_scores(
tracked_modules_name=module_partition_names[module_partition],
)

release_memory()
self._reset_memory()
start_time = get_time(state=self.state)
with self.profiler.profile("Compute Self-Influence Score"):
train_loader = self._get_dataloader(
Expand Down Expand Up @@ -722,7 +717,7 @@ def compute_self_scores(
self.aggregate_self_scores(scores_name=scores_name)
self.logger.info(f"Saved aggregated self-influence scores at `{scores_output_dir}`.")
self.state.wait_for_everyone()
self._log_profile_summary()
self._log_profile_summary(name=f"scores_{scores_name}_self")

@torch.no_grad()
def aggregate_self_scores(self, scores_name: str) -> None:
Expand Down
34 changes: 22 additions & 12 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,19 +557,29 @@ def _compute_low_rank_preconditioned_gradient(
List[torch.Tensor, torch.Tensor]:
Low-rank matrices that approximate the original preconditioned query gradient.
"""
U, S, V = torch.linalg.svd( # pylint: disable=not-callable
preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype),
full_matrices=False,
)
rank = self.score_args.query_gradient_rank
U_k = U[:, :, :rank]
S_k = S[:, :rank]
# Avoids holding the full memory of the original tensor before indexing.
V_k = V[:, :rank, :].contiguous().clone()
return [
torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=self.score_args.score_dtype).contiguous().clone(),
V_k.to(dtype=self.score_args.score_dtype),
]
if self.score_args.use_full_svd:
U, S, V = torch.linalg.svd( # pylint: disable=not-callable
preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype),
full_matrices=False,
)
U_k = U[:, :, :rank]
S_k = S[:, :rank]
# Avoids holding the full memory of the original tensor before indexing.
V_k = V[:, :rank, :].contiguous().clone()
return [
torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=self.score_args.score_dtype).contiguous().clone(),
V_k.to(dtype=self.score_args.score_dtype),
]
else:
U, S, V = torch.svd_lowrank(
preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype),
q=rank,
)
return [
torch.matmul(U, torch.diag_embed(S)).to(dtype=self.score_args.score_dtype).contiguous().clone(),
V.transpose(1, 2).contiguous().to(dtype=self.score_args.score_dtype),
]

def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> None:
"""Computes the preconditioned per-sample-gradient.
Expand Down
7 changes: 4 additions & 3 deletions kronfluence/module/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def wrap_tracked_modules(
tracked_module_exists_dict = None
if tracked_module_names is not None:
tracked_module_exists_dict = {name: False for name in tracked_module_names}
per_sample_gradient_process_fnc = None
if task is not None and task.do_post_process_per_sample_gradient:
per_sample_gradient_process_fnc = task.post_process_per_sample_gradient

named_modules = model.named_modules()
for module_name, module in named_modules:
Expand All @@ -68,9 +71,7 @@ def wrap_tracked_modules(
tracked_module = TrackedModule.SUPPORTED_MODULES[type(module)](
name=module_name,
original_module=module,
per_sample_gradient_process_fnc=task.post_process_per_sample_gradient
if task.do_post_process_per_sample_gradient
else None,
per_sample_gradient_process_fnc=per_sample_gradient_process_fnc,
factor_args=factor_args,
score_args=score_args,
)
Expand Down
2 changes: 1 addition & 1 deletion kronfluence/utils/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def verify_models_equivalence(state_dict1: Dict[str, torch.Tensor], state_dict2:
for name in state_dict1:
tensor1 = state_dict1[name].to(dtype=torch.float32).cpu()
tensor2 = state_dict2[name].to(dtype=torch.float32).cpu()
if not torch.allclose(tensor1, tensor2, rtol=1e-3, atol=1e-5):
if not torch.allclose(tensor1, tensor2, rtol=1.3e-6, atol=1e-5):
return False

return True
26 changes: 3 additions & 23 deletions tests/modules/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from kronfluence.arguments import FactorArguments
from kronfluence.module.tracked_module import ModuleMode
from kronfluence.module.utils import set_mode, wrap_tracked_modules
from kronfluence.utils.save import verify_models_equivalence
from tests.utils import prepare_test


@pytest.mark.parametrize(
"test_name",
["mlp", "conv", "conv_bn", "bert", "gpt"],
["mlp", "conv_bn", "gpt"],
)
@pytest.mark.parametrize(
"mode",
Expand All @@ -34,6 +33,7 @@ def test_tracked_modules_forward_equivalence(
train_size: int,
seed: int,
) -> None:
# The forward pass should produce the same results with and without wrapped modules.
model, train_dataset, _, data_collator, task = prepare_test(
test_name=test_name,
train_size=train_size,
Expand Down Expand Up @@ -92,6 +92,7 @@ def test_tracked_modules_backward_equivalence(
train_size: int,
seed: int,
) -> None:
# The backward pass should produce the same results with and without wrapped modules.
model, train_dataset, _, data_collator, task = prepare_test(
test_name=test_name,
train_size=train_size,
Expand Down Expand Up @@ -131,24 +132,3 @@ def test_tracked_modules_backward_equivalence(
original_name = name.replace(".original_module", "")
if original_name in original_grads:
assert torch.allclose(grad, original_grads[original_name])


def test_verify_models_equivalence() -> None:
model1, _, _, _, _ = prepare_test(
test_name="mlp",
train_size=10,
seed=0,
)
model2, _, _, _, _ = prepare_test(
test_name="mlp",
train_size=10,
seed=1,
)
model3, _, _, _, _ = prepare_test(
test_name="conv",
train_size=10,
seed=1,
)
assert verify_models_equivalence(model1.state_dict(), model1.state_dict())
assert not verify_models_equivalence(model1.state_dict(), model2.state_dict())
assert not verify_models_equivalence(model1.state_dict(), model3.state_dict())
4 changes: 3 additions & 1 deletion tests/modules/test_per_sample_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def for_loop_per_sample_gradient(
"conv_bn",
"bert",
"roberta",
"gpt",
],
)
@pytest.mark.parametrize("use_measurement", [True, False])
Expand Down Expand Up @@ -192,6 +193,7 @@ def test_for_loop_per_sample_gradient_equivalence(
"conv_bn",
"bert",
"roberta",
"gpt",
],
)
@pytest.mark.parametrize("use_measurement", [True, False])
Expand Down Expand Up @@ -305,7 +307,7 @@ def test_mean_gradient_equivalence(
[
"mlp",
"conv",
"gpt",
"roberta",
],
)
@pytest.mark.parametrize("train_size", [32])
Expand Down
38 changes: 37 additions & 1 deletion tests/modules/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_query_gradient_svd(
@pytest.mark.parametrize("output_dim", [512, 1024])
@pytest.mark.parametrize("batch_dim", [8, 16])
@pytest.mark.parametrize("qbatch_dim", [8, 16])
@pytest.mark.parametrize("rank", [32])
@pytest.mark.parametrize("rank", [8])
@pytest.mark.parametrize("seed", [0])
def test_query_gradient_svd_reconst(
input_dim: int,
Expand Down Expand Up @@ -150,3 +150,39 @@ def test_query_gradient_svd_reconst(
assert intermediate2.numel() <= reconst_numel
else:
assert intermediate.numel() <= reconst_numel


def test_query_gradient_svd_vs_low_rank_svd(
seed: int = 0,
) -> None:
input_dim = 2048
output_dim = 1024
batch_dim = 16
set_seed(seed)

# gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float32)

rank = 32
lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank), dtype=torch.float32)
lr_gradient2 = torch.rand(size=(batch_dim, rank, input_dim), dtype=torch.float32)
gradient = torch.bmm(lr_gradient1, lr_gradient2)

U, S, V = torch.linalg.svd(
gradient.contiguous(),
full_matrices=False,
)

U_k = U[:, :, :rank]
S_k = S[:, :rank]
V_k = V[:, :rank, :].clone()
left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous()
assert torch.bmm(left, right).shape == gradient.shape
print(f"Error: {(torch.bmm(left, right) - gradient).norm()}")

new_U, new_S, new_V = torch.svd_lowrank(
gradient.contiguous(),
q=rank,
)
new_left, new_right = torch.matmul(new_U, torch.diag_embed(new_S)).contiguous(), new_V.transpose(1, 2).contiguous()
assert torch.bmm(new_left, new_right).shape == gradient.shape
print(f"Error: {(torch.bmm(new_left, new_right) - gradient).norm()}")
Loading

0 comments on commit db1c89f

Please sign in to comment.