diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index cfe1a4e..f0bcb2e 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -32,5 +32,6 @@ jobs: pytest -vx tests/factors/test_lambdas.py pytest -vx tests/modules/test_modules.py pytest -vx tests/modules/test_per_sample_gradients.py + pytest -vx tests/modules/test_matmul.py pytest -vx tests/scores/test_pairwise_scores.py pytest -vx tests/scores/test_self_scores.py \ No newline at end of file diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index b07e0ff..f20b660 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -201,7 +201,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) if damping_factor is None: damping_factor = 0.1 * torch.mean(lambda_matrix) lambda_matrix.add_(damping_factor) - storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu") + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() def precondition_gradient( self, @@ -249,10 +249,10 @@ def requires_lambda_matrices_for_precondition(self) -> bool: def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to( dtype=score_args.precondition_dtype - ) + ).contiguous() storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to( dtype=score_args.precondition_dtype - ) + ).contiguous() activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device) gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device) lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) @@ -260,7 +260,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) if damping_factor is None: damping_factor = 0.1 * torch.mean(lambda_matrix) lambda_matrix.add_(damping_factor) - storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu") + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() storage[ACTIVATION_EIGENVALUES_NAME] = None storage[GRADIENT_EIGENVALUES_NAME] = None @@ -316,20 +316,19 @@ def requires_lambda_matrices_for_precondition(self) -> bool: def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to( dtype=score_args.precondition_dtype - ) + ).contiguous() storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to( dtype=score_args.precondition_dtype - ) + ).contiguous() storage[ACTIVATION_EIGENVALUES_NAME] = None storage[GRADIENT_EIGENVALUES_NAME] = None - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) damping_factor = score_args.damping_factor if damping_factor is None: damping_factor = 0.1 * torch.mean(lambda_matrix) lambda_matrix.add_(damping_factor) - storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu") + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() @torch.no_grad() def precondition_gradient( diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 35a7456..3c7d37b 
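# Illustrative sketch (hypothetical helper, not part of the patch): the `prepare()` changes above
# all append `.contiguous()` because `.to(dtype=..., device="cpu")` can hand back a non-contiguous
# tensor when the source came from transposes or slices; storing a dense copy avoids surprises in
# later matmul / opt_einsum score paths. The damping default mirrors the diff: 0.1 * mean of the
# averaged lambda matrix when no explicit damping factor is given.
from typing import Optional

import torch

def finalize_lambda(lambda_matrix: torch.Tensor, num_processed: torch.Tensor,
                    damping_factor: Optional[float], dtype: torch.dtype) -> torch.Tensor:
    lambda_matrix = lambda_matrix / num_processed            # average over processed samples
    if damping_factor is None:
        damping_factor = 0.1 * torch.mean(lambda_matrix)     # heuristic damping, as in the diff
    lambda_matrix = lambda_matrix + damping_factor
    return lambda_matrix.to(dtype=dtype, device="cpu").contiguous()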
100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -232,7 +232,7 @@ def fit_covariance_matrices_with_loader( state.wait_for_everyone() num_data_processed.add_(find_batch_size(data=batch)) - del batch, attention_mask, loss + del loss total_steps += 1 pbar.update(1) diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index 90dbdd2..ff027ea 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -429,7 +429,7 @@ def fit_lambda_matrices_with_loader( state.wait_for_everyone() num_data_processed.add_(find_batch_size(data=batch)) - del batch, loss + del loss total_steps += 1 pbar.update(1) diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 1d30cec..1becff0 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -159,7 +159,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradien input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") summed_gradient = contract("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0) - return summed_gradient.view((1, *summed_gradient.size())) + return summed_gradient def compute_per_sample_gradient( self, @@ -176,14 +176,12 @@ def compute_per_sample_gradient( ) return per_sample_gradient - @torch.no_grad() def compute_pairwise_score( - self, preconditioned_gradient, input_activation: torch.Tensor, output_gradient: torch.Tensor + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor ) -> torch.Tensor: input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - if isinstance(preconditioned_gradient, list): left_mat, right_mat = preconditioned_gradient if self.einsum_expression is None: diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index d8d4a51..0fbf40a 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -77,7 +77,7 @@ def compute_per_sample_gradient( return per_sample_gradient def compute_pairwise_score( - self, preconditioned_gradient, input_activation: torch.Tensor, output_gradient: torch.Tensor + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor ) -> torch.Tensor: input_activation = self._flatten_input_activation(input_activation=input_activation) if isinstance(preconditioned_gradient, list): diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index b05e023..ec57ad8 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -31,8 +31,11 @@ class ModuleMode(str, BaseEnum): - """Enum representing a module's mode, indicating which factors and scores - need to be computed during forward and backward passes.""" + """Enum representing a module's mode for factor and score computation. + + This enum indicates which factors and scores need to be computed during + forward and backward passes. 
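# Illustrative sketch (assumed toy shapes, not part of the patch): the conv2d.py change removes a
# redundant `.view((1, *size))` in `compute_summed_gradient`. The `contract(...)` already sums over
# the batch and spatial positions, and `.unsqueeze_(dim=0)` adds the leading singleton dimension,
# so the old extra view produced a (1, 1, d_out, d_in) tensor instead of the expected (1, d_out, d_in).
import torch
from opt_einsum import contract

batch, positions, d_out, d_in = 4, 9, 8, 27
output_gradient = torch.randn(batch, positions, d_out)   # after rearrange "b o i1 i2 -> b (i1 i2) o"
input_activation = torch.randn(batch, positions, d_in)   # flattened unfolded patches
summed = contract("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0)
assert summed.shape == (1, d_out, d_in)                   # old code wrapped this in another view -> (1, 1, d_out, d_in)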
+ """ DEFAULT = "default" COVARIANCE = "covariance" @@ -124,7 +127,7 @@ def __init__( self._initialize_storage() def _initialize_storage(self) -> None: - """Initializes trackers for different module modes.""" + """Initializes storage for various factors and scores.""" # Storage for activation and pseudo-gradient covariance matrices # for covariance_factor_name in COVARIANCE_FACTOR_NAMES: @@ -138,22 +141,44 @@ def _initialize_storage(self) -> None: for lambda_factor_name in LAMBDA_FACTOR_NAMES: self.storage[lambda_factor_name]: Optional[torch.Tensor] = None - # Storage for preconditioned query gradients and influence scores # + # Storage for preconditioned gradients and influence scores # self.storage[AGGREGATED_GRADIENT_NAME]: Optional[torch.Tensor] = None self.storage[PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None self.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None self.storage[PAIRWISE_SCORE_MATRIX_NAME]: Optional[torch.Tensor] = None self.storage[SELF_SCORE_VECTOR_NAME]: Optional[torch.Tensor] = None - def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any: - """A forward pass of the tracked module. This should have identical behavior to that of the original module.""" + def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """Performs a forward pass of the tracked module. + + This method should have identical behavior to that of the original module. + + Args: + inputs (torch.Tensor): + Input tensor to the module. + *args: + Variable length argument list. + **kwargs: + Arbitrary keyword arguments. + + Returns: + torch.Tensor: + The output of the forward pass. + """ outputs = self.original_module(inputs, *args, **kwargs) if outputs.requires_grad: return outputs return outputs + self._constant def prepare_storage(self, device: torch.device) -> None: - """Performs any necessary operations on storage before computing influence scores.""" + """Prepares storage for computing influence scores. + + This method performs necessary operations on storage before computing influence scores. + + Args: + device (torch.device): + The device to prepare the storage for. + """ FactorConfig.CONFIGS[self.factor_args.strategy].prepare( storage=self.storage, score_args=self.score_args, @@ -161,33 +186,73 @@ def prepare_storage(self, device: torch.device) -> None: ) def update_factor_args(self, factor_args: FactorArguments) -> None: - """Updates the factor arguments.""" + """Updates the factor arguments. + + Args: + factor_args (FactorArguments): + New factor arguments to set. + """ self.factor_args = factor_args def update_score_args(self, score_args: ScoreArguments) -> None: - """Updates the score arguments.""" + """Updates the score arguments. + + Args: + score_args (ScoreArguments): + New score arguments to set. + """ self.score_args = score_args def get_factor(self, factor_name: str) -> Optional[torch.Tensor]: - """Returns the factor with the given name.""" + """Retrieves a factor by name from storage. + + Args: + factor_name (str): + The name of the factor to retrieve. + + Returns: + Optional[torch.Tensor]: + The requested factor, or `None` if not found. + """ if factor_name not in self.storage or self.storage[factor_name] is None: return None return self.storage[factor_name] def release_factor(self, factor_name: str) -> None: - """Release the factor with the given name from memory.""" + """Releases a factor from memory. + + Args: + factor_name (str): + The name of the factor to release. 
+ """ if factor_name not in self.storage or self.storage[factor_name] is None: return None del self.storage[factor_name] self.storage[factor_name] = None def set_factor(self, factor_name: str, factor: Any) -> None: - """Sets the factor with the given name.""" + """Sets a factor in storage. + + Args: + factor_name (str): + The name of the factor to set. + factor (Any): + The factor value to store. + """ if factor_name in self.storage: self.storage[factor_name] = factor def set_mode(self, mode: ModuleMode, release_memory: bool = False) -> None: - """Sets the module mode of the current `TrackedModule` instance.""" + """Sets the operating mode of the `TrackedModule`. + + This method changes the current mode and manages associated trackers and memory. + + Args: + mode (ModuleMode): + The new mode to set. + release_memory (bool): + Whether to release memory for all trackers. + """ self._trackers[self.current_mode].release_hooks() self.einsum_expression = None self.current_mode = mode @@ -199,11 +264,21 @@ def set_mode(self, mode: ModuleMode, release_memory: bool = False) -> None: self._trackers[self.current_mode].register_hooks() def set_attention_mask(self, attention_mask: Optional[torch.Tensor] = None) -> None: - """Sets the attention mask for activation covariance computations.""" + """Sets the attention mask for activation covariance computations. + + Args: + attention_mask (torch.Tensor, optional): + The attention mask to set. + """ self.attention_mask = attention_mask def set_gradient_scale(self, scale: float = 1.0) -> None: - """Sets the scale of the gradient obtained from `GradScaler`.""" + """Sets the scale of the gradient obtained from `GradScaler`. + + Args: + scale (float): + The scale factor to set. + """ self.gradient_scale = scale def finalize_iteration(self) -> None: @@ -211,7 +286,12 @@ def finalize_iteration(self) -> None: self._trackers[self.current_mode].finalize_iteration() def exist(self) -> bool: - """Checks if the desired statistics are available.""" + """Checks if the desired statistics are available. + + Returns: + bool: + `True` if statistics exist, `False` otherwise. + """ return self._trackers[self.current_mode].exist() def synchronize(self, num_processes: int) -> None: diff --git a/kronfluence/module/tracker/base.py b/kronfluence/module/tracker/base.py index 1e0fc3f..c6f603f 100644 --- a/kronfluence/module/tracker/base.py +++ b/kronfluence/module/tracker/base.py @@ -76,7 +76,12 @@ def finalize_iteration(self) -> None: """Finalizes statistics for the current iteration.""" def exist(self) -> bool: - """Checks if the desired statistics are available.""" + """Checks if the desired statistics are available. + + Returns: + bool: + `True` if statistics exist, `False` otherwise. 
+ """ return False def synchronize(self, num_processes: int) -> None: diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py index 395b3ac..c649f97 100644 --- a/kronfluence/module/tracker/factor.py +++ b/kronfluence/module/tracker/factor.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Union import torch import torch.distributed as dist @@ -25,42 +25,51 @@ class CovarianceTracker(BaseTracker): """Tracks and computes activation and gradient covariance matrices for a given module.""" - def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -> None: + _activation_covariance_initialized: bool = False + _gradient_covariance_initialized: bool = False + + def _update_activation_covariance_matrix( + self, input_activation: torch.Tensor, count: Union[torch.Tensor, int] + ) -> None: """Computes and updates the activation covariance matrix. Args: input_activation (torch.Tensor): - The input tensor to the module, provided by PyTorch's forward hook. + The flattened input tensor to the module, provided by PyTorch's forward hook. + count (int): + The number of activations. """ - flattened_activation, count = self.module.get_flattened_activation(input_activation=input_activation) - - if self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] is None: + if not self._activation_covariance_initialized: self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros( size=(1,), dtype=torch.int64, device=count.device if isinstance(count, torch.Tensor) else None, requires_grad=False, ) - dimension = flattened_activation.size(1) + dimension = input_activation.size(1) self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME] = torch.zeros( size=(dimension, dimension), - dtype=flattened_activation.dtype, - device=flattened_activation.device, + dtype=input_activation.dtype, + device=input_activation.device, requires_grad=False, ) + self._activation_covariance_initialized = True self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count) - self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) + self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(input_activation.t(), input_activation) - def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None: + def _update_gradient_covariance_matrix( + self, output_gradient: torch.Tensor, count: Union[torch.Tensor, int] + ) -> None: """Computes and updates the pseudo-gradient covariance matrix. Args: output_gradient (torch.Tensor): - The gradient tensor with respect to the output of the module, provided by PyTorch's backward hook. + The flattened gradient tensor with respect to the output of the module, provided + by PyTorch's backward hook. + count (int): + The number of gradients. """ - flattened_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient) - - if self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None: + if not self._gradient_covariance_initialized: # In most cases, `NUM_GRADIENT_COVARIANCE_PROCESSED` and `NUM_ACTIVATION_COVARIANCE_PROCESSED` are # identical. However, they may differ when using gradient checkpointing or `torch.compile()`. 
self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( @@ -69,15 +78,16 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N device=count.device if isinstance(count, torch.Tensor) else None, requires_grad=False, ) - dimension = flattened_gradient.size(1) + dimension = output_gradient.size(1) self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME] = torch.zeros( size=(dimension, dimension), - dtype=flattened_gradient.dtype, - device=flattened_gradient.device, + dtype=output_gradient.dtype, + device=output_gradient.device, requires_grad=False, ) + self._gradient_covariance_initialized = True self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count) - self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) + self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient) def register_hooks(self) -> None: """Sets up hooks to compute activation and gradient covariance matrices.""" @@ -94,7 +104,8 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch. ) ) # Computes and updates activation covariance during forward pass. - self._update_activation_covariance_matrix(input_activation=input_activation) + input_activation, count = self.module.get_flattened_activation(input_activation=input_activation) + self._update_activation_covariance_matrix(input_activation=input_activation, count=count) self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() @@ -102,10 +113,11 @@ def backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() output_gradient = self._preprocess_gradient( - output_gradient, target_dtype=self.module.factor_args.gradient_covariance_dtype + output_gradient.detach(), target_dtype=self.module.factor_args.gradient_covariance_dtype ) # Computes and updates pseudo-gradient covariance during backward pass. - self._update_gradient_covariance_matrix(output_gradient=output_gradient) + output_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient) + self._update_gradient_covariance_matrix(output_gradient=output_gradient, count=count) self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @@ -130,6 +142,8 @@ def synchronize(self, num_processes: int) -> None: def release_memory(self) -> None: """Clears all covariance matrices from memory.""" + self._activation_covariance_initialized = False + self._gradient_covariance_initialized = False for covariance_factor_name in COVARIANCE_FACTOR_NAMES: self.module.storage[covariance_factor_name] = None @@ -246,13 +260,14 @@ def backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() output_gradient = self._preprocess_gradient( - output_gradient=output_gradient, target_dtype=self.module.factor_args.per_sample_gradient_dtype + output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype ) per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, ).to(dtype=self.module.factor_args.lambda_dtype) self.clear_all_cache() + del output_gradient # Computes and updates lambda matrix during backward pass. 
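# Illustrative sketch (hypothetical helper, not part of the patch): the refactored CovarianceTracker
# now receives already-flattened activations plus a count from the module, and tracks initialization
# with boolean flags instead of probing storage for `None`. The running covariance update itself is
# an in-place rank-k accumulation, as in the hooks above.
import torch

def update_covariance(cov: torch.Tensor, processed: torch.Tensor,
                      flat: torch.Tensor, count: int) -> None:
    processed.add_(count)           # number of flattened rows folded into `cov`
    cov.addmm_(flat.t(), flat)      # cov += flat^T @ flat

dim = 16
cov = torch.zeros(dim, dim)
processed = torch.zeros(1, dtype=torch.int64)
batch = torch.randn(32, dim)        # 32 flattened activation rows
update_covariance(cov, processed, batch, count=batch.size(0))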
self._update_lambda_matrix(per_sample_gradient=per_sample_gradient) @@ -261,7 +276,7 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() output_gradient = self._preprocess_gradient( - output_gradient=output_gradient, target_dtype=self.module.factor_args.per_sample_gradient_dtype + output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype ) cached_activation = self.cached_activations.pop() per_sample_gradient = self.module.compute_per_sample_gradient( diff --git a/kronfluence/module/tracker/gradient.py b/kronfluence/module/tracker/gradient.py index 82863ec..b7a3de3 100644 --- a/kronfluence/module/tracker/gradient.py +++ b/kronfluence/module/tracker/gradient.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist -import torch.nn as nn +from torch import nn from kronfluence.module.tracker.base import BaseTracker from kronfluence.utils.constants import AGGREGATED_GRADIENT_NAME @@ -14,23 +14,22 @@ class GradientTracker(BaseTracker): def register_hooks(self) -> None: """Sets up hooks to compute and keep track of aggregated gradient.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.per_sample_gradient_dtype, - copy=True, - ) - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() @@ -39,15 +38,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self._raise_cache_not_found_exception() handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale - + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) if isinstance(self.cached_activations, list): cached_activation = self.cached_activations.pop() else: @@ -57,22 +50,20 @@ def backward_hook(output_gradient: torch.Tensor) -> None: input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) + self.clear_all_cache() else: summed_gradient = self.module.compute_per_sample_gradient( input_activation=cached_activation.to(device=output_gradient.device), 
output_gradient=output_gradient, ).sum(dim=0, keepdim=True) - self.clear_all_cache() - if self.module.storage[AGGREGATED_GRADIENT_NAME] is None: self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False) self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient) self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) - @torch.no_grad() def finalize_iteration(self): - """Clears all cached activations from memory.""" + """Clears all cached data from memory.""" self.clear_all_cache() def exist(self) -> bool: diff --git a/kronfluence/module/tracker/pairwise_score.py b/kronfluence/module/tracker/pairwise_score.py index 94696ee..ee0c28f 100644 --- a/kronfluence/module/tracker/pairwise_score.py +++ b/kronfluence/module/tracker/pairwise_score.py @@ -1,8 +1,8 @@ from typing import Tuple import torch -import torch.nn as nn from opt_einsum import DynamicProgramming, contract, contract_expression +from torch import nn from kronfluence.module.tracker.base import BaseTracker from kronfluence.utils.constants import ( @@ -50,23 +50,22 @@ def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tenso def register_hooks(self) -> None: """Sets up hooks to compute pairwise influence scores.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.score_dtype, - copy=True, - ) - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() @@ -75,15 +74,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self._raise_cache_not_found_exception() handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.score_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale - + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.score_dtype + ) if isinstance(self.cached_activations, list): cached_activation = self.cached_activations.pop() else: @@ -101,14 +94,13 @@ def backward_hook(output_gradient: torch.Tensor) -> None: input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) - self.clear_all_cache() + del cached_activation, output_gradient 
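# Illustrative sketch (hypothetical standalone version, not part of the patch): the backward hooks
# above replace the repeated dtype-cast / gradient-scale snippet with a single `_preprocess_gradient`
# call on the base tracker. The removed inline logic, mirrored here, was equivalent to:
import torch

def preprocess_gradient(output_gradient: torch.Tensor, target_dtype: torch.dtype,
                        gradient_scale: float = 1.0) -> torch.Tensor:
    original_dtype = output_gradient.dtype
    output_gradient = output_gradient.to(dtype=target_dtype)
    if gradient_scale != 1.0:
        if original_dtype != target_dtype:
            output_gradient.mul_(gradient_scale)              # in-place is safe on the cast copy
        else:
            output_gradient = output_gradient * gradient_scale  # avoid mutating the original grad
    return output_gradient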
self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient) self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) - @torch.no_grad() def finalize_iteration(self) -> None: - """Clears all cached activations from memory.""" + """Clears all cached data from memory.""" self.clear_all_cache() def exist(self) -> bool: @@ -119,12 +111,13 @@ def accumulate_iterations(self) -> None: """Removes pairwise scores from memory after a single iteration.""" self.release_memory() + @torch.no_grad() def finalize_all_iterations(self) -> None: """Removes cached preconditioned gradient from memory. Additionally, if aggregated gradients are available, computes the pairwise score using them.""" if self.module.storage[AGGREGATED_GRADIENT_NAME] is not None: self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].to( - dtype=self.module.score_args.precondition_dtype + dtype=self.module.score_args.score_dtype ) self._compute_pairwise_score_with_gradient( per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME] diff --git a/kronfluence/module/tracker/precondition.py b/kronfluence/module/tracker/precondition.py index a566bfe..b93c397 100644 --- a/kronfluence/module/tracker/precondition.py +++ b/kronfluence/module/tracker/precondition.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist -import torch.nn as nn +from torch import nn from kronfluence.factor.config import FactorConfig from kronfluence.module.tracker.base import BaseTracker @@ -51,19 +51,13 @@ def _compute_low_rank_preconditioned_gradient( V = V.transpose(1, 2).to(dtype=target_dtype) return [left_mat, V] - def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> None: - """Computes the preconditioned per-sample gradient. + def _process_preconditioned_gradient(self, preconditioned_gradient: torch.Tensor) -> None: + """Processes the preconditioned per-sample gradient. Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. + preconditioned_gradient (torch.Tensor): + The preconditioned per-sample gradient tensor for the given batch. 
""" - preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( - gradient=per_sample_gradient, - storage=self.module.storage, - ) - del per_sample_gradient - if ( self.module.score_args.query_gradient_low_rank is not None and min(preconditioned_gradient.size()[1:]) > self.module.score_args.query_gradient_low_rank @@ -83,23 +77,22 @@ def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> def register_hooks(self) -> None: """Sets up hooks to compute preconditioned per-sample gradient.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.per_sample_gradient_dtype, - copy=True, - ) - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation self.cached_hooks.append( outputs.register_hook( shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook @@ -112,34 +105,30 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self._raise_cache_not_found_exception() handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, ).to(dtype=self.module.score_args.precondition_dtype) self.clear_all_cache() + del output_gradient # Computes preconditioned per-sample gradient during backward pass. 
- self._compute_preconditioned_gradient(per_sample_gradient=per_sample_gradient) + preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=per_sample_gradient, + storage=self.module.storage, + ) + del per_sample_gradient + self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient) @torch.no_grad() def shared_backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) cached_activation = self.cached_activations.pop() per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=cached_activation.to(device=output_gradient.device), @@ -159,7 +148,12 @@ def finalize_iteration(self) -> None: self.cached_per_sample_gradient = self.cached_per_sample_gradient.to( dtype=self.module.score_args.precondition_dtype ) - self._compute_preconditioned_gradient(per_sample_gradient=self.cached_per_sample_gradient) + preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=self.cached_per_sample_gradient, + storage=self.module.storage, + ) + self.cached_per_sample_gradient = None + self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient) self.clear_all_cache() def exist(self) -> bool: @@ -211,13 +205,13 @@ def truncate(self, keep_size: int) -> None: if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): assert len(self.module.storage[PRECONDITIONED_GRADIENT_NAME]) == 2 self.module.storage[PRECONDITIONED_GRADIENT_NAME] = [ - self.module.storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size], - self.module.storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size], + self.module.storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size].clone(), + self.module.storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size].clone(), ] else: self.module.storage[PRECONDITIONED_GRADIENT_NAME] = self.module.storage[PRECONDITIONED_GRADIENT_NAME][ :keep_size - ] + ].clone() def accumulate_iterations(self) -> None: """Accumulates preconditioned gradient across multiple iterations.""" @@ -245,14 +239,19 @@ def accumulate_iterations(self) -> None: del gradient, self.module.storage[PRECONDITIONED_GRADIENT_NAME] self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + @torch.no_grad() def finalize_all_iterations(self) -> None: """Preconditions aggregated gradient if it exists in storage.""" if self.module.storage[AGGREGATED_GRADIENT_NAME] is not None: self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].to( dtype=self.module.score_args.precondition_dtype ) - self._compute_preconditioned_gradient(per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME]) + preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=self.module.storage[AGGREGATED_GRADIENT_NAME], + storage=self.module.storage, + ) self.module.storage[AGGREGATED_GRADIENT_NAME] = None + 
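# Illustrative sketch (not part of the patch): `truncate` now calls `.clone()` on the sliced
# preconditioned gradient. A plain slice is a view that keeps the original, larger storage alive;
# cloning materializes only the kept rows so the remainder can actually be freed.
import torch

full = torch.randn(512, 64, 64)
view_only = full[:8]                 # still references all 512 * 64 * 64 elements
kept = full[:8].clone()              # owns just 8 * 64 * 64 elements
assert kept.untyped_storage().size() < view_only.untyped_storage().size()  # PyTorch >= 2.0 API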
self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient) self.accumulate_iterations() def release_memory(self) -> None: diff --git a/kronfluence/module/tracker/self_score.py b/kronfluence/module/tracker/self_score.py index 248c0e4..c8230dd 100644 --- a/kronfluence/module/tracker/self_score.py +++ b/kronfluence/module/tracker/self_score.py @@ -1,9 +1,9 @@ from typing import Tuple import torch -import torch.nn as nn +from torch import nn -from kronfluence.factor.config import FactorConfig +from kronfluence.factor.config import STORAGE_TYPE, FactorConfig from kronfluence.module.tracker.base import BaseTracker from kronfluence.utils.constants import ( PRECONDITIONED_GRADIENT_NAME, @@ -11,14 +11,21 @@ ) -def move_storage_to_device(storage, target_device: torch.device) -> None: - """Moves stored factors into the target device.""" +def move_storage_to_device(storage: STORAGE_TYPE, target_device: torch.device) -> None: + """Moves all stored factors in the storage dictionary to the specified target device. + + Args: + storage (STORAGE_TYPE): + A dictionary containing stored factors. + target_device (torch.device): + The target device to move the factors to. + """ for name, factor in storage.items(): if factor is not None: if isinstance(factor, list): for i in range(len(storage[name])): storage[name][i] = factor[i].to(device=target_device) - else: + if isinstance(factor, torch.Tensor): storage[name] = factor.to(device=target_device) @@ -56,23 +63,22 @@ def _compute_self_score(self, per_sample_gradient: torch.Tensor) -> None: def register_hooks(self) -> None: """Sets up hooks to compute self-influence scores.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.per_sample_gradient_dtype, - copy=True, - ) - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation self.cached_hooks.append( outputs.register_hook( shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook @@ -85,33 +91,24 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self._raise_cache_not_found_exception() handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + 
output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, ).to(dtype=self.module.score_args.precondition_dtype) self.clear_all_cache() + del output_gradient self._compute_self_score(per_sample_gradient=per_sample_gradient) @torch.no_grad() def shared_backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype + ) cached_activation = self.cached_activations.pop() per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=cached_activation.to(device=output_gradient.device), @@ -144,6 +141,8 @@ def accumulate_iterations(self) -> None: def release_memory(self) -> None: """Releases self-influence scores from memory.""" self.clear_all_cache() + if self.storage_at_device: + move_storage_to_device(storage=self.module.storage, target_device=torch.device("cpu")) self.storage_at_device = False del self.module.storage[SELF_SCORE_VECTOR_NAME] self.module.storage[SELF_SCORE_VECTOR_NAME] = None @@ -169,24 +168,24 @@ def _compute_self_measurement_score_with_gradient(self, per_sample_gradient: tor self.module.storage[SELF_SCORE_VECTOR_NAME].add_(scores) def register_hooks(self) -> None: - """Sets up hooks to compute pairwise influence scores.""" + """Sets up hooks to compute self-influence scores with measurement.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.score_dtype, - copy=True, - ) - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() @@ -200,22 +199,16 @@ def backward_hook(output_gradient: torch.Tensor) -> None: target_device=output_gradient.device, ) self.storage_at_device = True + handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.score_args.score_dtype - output_gradient = 
output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale - + output_gradient = self._preprocess_gradient( + output_gradient.detach(), target_dtype=self.module.score_args.score_dtype + ) if isinstance(self.cached_activations, list): cached_activation = self.cached_activations.pop() else: cached_activation = self.cached_activations - if self.module.per_sample_gradient_process_fnc is None: scores = self.module.compute_self_measurement_score( preconditioned_gradient=self.module.storage[PRECONDITIONED_GRADIENT_NAME], @@ -233,14 +226,13 @@ def backward_hook(output_gradient: torch.Tensor) -> None: input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) - self.clear_all_cache() + del cached_activation, output_gradient self._compute_self_measurement_score_with_gradient(per_sample_gradient=per_sample_gradient) self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) - @torch.no_grad() def finalize_iteration(self) -> None: - """Removes all cached activations from memory.""" + """Clears all cached data from memory.""" self.clear_all_cache() def exist(self) -> bool: @@ -254,6 +246,8 @@ def accumulate_iterations(self) -> None: def release_memory(self) -> None: """Releases self-influence scores from memory.""" self.clear_all_cache() + if self.storage_at_device: + move_storage_to_device(storage=self.module.storage, target_device=torch.device("cpu")) self.storage_at_device = False del self.module.storage[SELF_SCORE_VECTOR_NAME] self.module.storage[SELF_SCORE_VECTOR_NAME] = None diff --git a/kronfluence/module/utils.py b/kronfluence/module/utils.py index 29b2a13..e858121 100644 --- a/kronfluence/module/utils.py +++ b/kronfluence/module/utils.py @@ -13,8 +13,18 @@ def _get_submodules(model: nn.Module, key: str) -> Tuple[nn.Module, str]: - """Returns the parent module and its name given the name of the current module.""" - # The code is modified from: https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py. + """Retrieves the parent module and its name given the name of the current module. + + Args: + model (nn.Module): + The PyTorch model to inspect. + key (str): + The full name of the current module. + + Returns: + Tuple[nn.Module, str]: + The parent module and the name of the target module. + """ parent = model.get_submodule(".".join(key.split(".")[:-1])) target_name = key.split(".")[-1] return parent, target_name @@ -34,13 +44,13 @@ def wrap_tracked_modules( task (Task): The specific task associated with the model. factor_args (FactorArguments, optional): - Arguments related to computing the influence factors. + Arguments related to computing influence factors. score_args (ScoreArguments, optional): - Arguments related to computing the influence scores. + Arguments related to computing influence scores. Returns: nn.Module: - The wrapped Pytorch model with `TrackedModule` installed. + The processed model with `TrackedModule` installed. """ if isinstance(model, (DP, DDP, FSDP)): raise ValueError( @@ -48,7 +58,6 @@ def wrap_tracked_modules( "or FullyShardedDataParallel. Call `prepare_model` before wrapping the model." 
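# Illustrative sketch (simplified, not part of the patch): `move_storage_to_device` now guards the
# tensor branch with an explicit `isinstance` check, so non-tensor entries left in storage (e.g.
# `None` or plain Python values) are no longer sent `.to(device=...)`; `release_memory` also moves
# factors back to CPU when they were previously transferred to the compute device.
import torch

def move_storage_to_device(storage: dict, target_device: torch.device) -> None:
    for name, factor in storage.items():
        if factor is None:
            continue
        if isinstance(factor, list):
            storage[name] = [item.to(device=target_device) for item in factor]
        elif isinstance(factor, torch.Tensor):
            storage[name] = factor.to(device=target_device)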
) - tracked_module_count = 0 tracked_module_names = task.get_influence_tracked_modules() if task is not None else None tracked_module_exists_dict = None if tracked_module_names is not None: @@ -77,7 +86,6 @@ def wrap_tracked_modules( ) parent, target_name = _get_submodules(model=model, key=module_name) setattr(parent, target_name, tracked_module) - tracked_module_count += 1 if tracked_module_exists_dict is not None: tracked_module_exists_dict[module_name] = True @@ -88,22 +96,34 @@ def wrap_tracked_modules( ) raise IllegalTaskConfigurationError(error_msg) - if tracked_module_count == 0: - supported_modules_names = [module.__name__ for module in TrackedModule.SUPPORTED_MODULES] - error_msg = ( - f"Kronfluence currently supports following PyTorch modules: `{supported_modules_names}`. " - f"However, these modules were not found in the provided model. If you want to analyze " - "custom layers, consider rewriting your model to use the supported modules, " - "or define your own custom module by subclassing `TrackedModule`." + if not any(isinstance(module, TrackedModule) for module in model.modules()): + supported_modules = ", ".join(module.__name__ for module in TrackedModule.SUPPORTED_MODULES) + raise IllegalTaskConfigurationError( + f"No supported modules found. Kronfluence supports: {supported_modules}. " + "Consider rewriting your model or subclassing `TrackedModule` for custom layers.\n" + f"Current Model:\n{model}" ) - error_msg += f"\nCurrent Model:\n{model}" - raise IllegalTaskConfigurationError(error_msg) - return model def make_modules_partition(total_module_names: List[str], partition_size: int) -> List[List[str]]: - """Divides a list of module names into smaller partitions of a specified size.""" + """Divides a list of module names into smaller partitions of a specified size. + + Args: + total_module_names (List[str]): + The list of all module names. + partition_size (int): + The number of partitions to create. + + Returns: + List[List[str]]: + A list of partitioned module names. + + Raises: + ValueError: If `len(total_module_names)` is less than `partition_size`. + """ + if len(total_module_names) < partition_size: + raise ValueError("The total modules must be equal to or greater than the partition size.") # See https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length. div, mod = divmod(len(total_module_names), partition_size) return list( @@ -117,20 +137,17 @@ def set_mode( tracked_module_names: List[str] = None, release_memory: bool = False, ) -> None: - """Sets the module mode of all `TrackedModule` instances within a model. For example, to compute - and update covariance matrices, the module mode needs to be set to `ModuleMode.COVARIANCE`. If - `tracked_module_names` are provided, the module mode is only set for modules listed in `tracked_module_names`. + """Sets the module mode of specified `TrackedModule` instances within a model. Args: model (nn.Module): - The PyTorch model which contains `TrackedModule`. + The PyTorch model containing `TrackedModule` instances. mode (ModuleMode): The new mode to set for `TrackedModule`. tracked_module_names (List[str], optional): - The list of names for `TrackedModule` to set the new mode. If not provided, the new mode is - set for all available `TrackedModule` within the model. + Names of modules to update. If `None`, updates all. release_memory (bool, optional): - If `False`, existing factors are kept in memory. + If `True`, releases memory of existing factors. 
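# Illustrative usage sketch (not part of the patch): `make_modules_partition` now validates that
# there are at least as many module names as requested partitions before splitting them with the
# usual divmod trick, where the i-th partition gets one extra element while i < mod.
names = [f"layer{i}" for i in range(5)]
div, mod = divmod(len(names), 2)                       # 2 partitions -> sizes 3 and 2
parts = [names[i * div + min(i, mod):(i + 1) * div + min(i + 1, mod)] for i in range(2)]
assert parts == [["layer0", "layer1", "layer2"], ["layer3", "layer4"]]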
""" for module in model.modules(): if isinstance(module, TrackedModule): @@ -140,26 +157,45 @@ def set_mode( def update_factor_args(model: nn.Module, factor_args: FactorArguments) -> None: - """Updates the factor arguments for all `TrackedModule` instances within a model.""" + """Updates the factor arguments for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + factor_args (FactorArguments): + The new factor arguments to set. + """ for module in model.modules(): if isinstance(module, TrackedModule): module.update_factor_args(factor_args=factor_args) def update_score_args(model: nn.Module, score_args: ScoreArguments) -> None: - """Updates the score arguments for all `TrackedModule` instances within a model.""" + """Updates the score arguments for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + score_args (ScoreArguments): + The new score arguments to set. + """ for module in model.modules(): if isinstance(module, TrackedModule): module.update_score_args(score_args=score_args) def get_tracked_module_names(model: nn.Module) -> List[str]: - """Returns the names of `TrackedModule` instances within a model.""" - tracked_modules = [] - for module in model.modules(): - if isinstance(module, TrackedModule): - tracked_modules.append(module.name) - return tracked_modules + """Returns the names of `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model to inspect. + + Returns: + List[str]: + A list of names of `TrackedModule` instances. + """ + return [module.name for module in model.modules() if isinstance(module, TrackedModule)] def load_factors( @@ -168,8 +204,22 @@ def load_factors( tracked_module_names: List[str] = None, cpu: bool = True, ) -> Dict[str, torch.Tensor]: - """Loads factors with the given name from all `TrackedModule` instances within a model (or all modules listed - in `tracked_module_names` if not `None`).""" + """Loads factors with the given name from specified `TrackedModule` instances. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + factor_name (str): + The name of the factor to load. + tracked_module_names (Optional[List[str]]): + Names of modules to load from. If `None`, loads from all. + cpu (bool): + If `True`, moves factors to CPU and releases GPU memory. + + Returns: + Dict[str, torch.Tensor]: + A dictionary of loaded factors, keyed by module name. + """ loaded_factors = {} for module in model.modules(): if isinstance(module, TrackedModule): @@ -186,7 +236,18 @@ def load_factors( def set_factors(model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor], clone: bool = False) -> None: - """Sets new factor for all `TrackedModule` instances within a model.""" + """Sets new factors for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + factor_name (str): + The name of the factor to set. + factors (Dict[str, torch.Tensor]): + A dictionary of factors to set, keyed by module name. + clone (bool): + If `True`, clones the factors before setting. 
+ """ for module in model.modules(): if isinstance(module, TrackedModule): module.set_factor( @@ -198,7 +259,18 @@ def set_attention_mask( model: nn.Module, attention_mask: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]] = None, ) -> None: - """Sets the attention mask for all `TrackedModule` instances within a model.""" + """Sets the attention mask for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + attention_mask (Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]): + The attention mask to set. Can be a dictionary (keyed by module name) or a single tensor. + + Raises: + RuntimeError: + If an invalid attention mask is provided. + """ for module in model.modules(): if isinstance(module, TrackedModule): if isinstance(attention_mask, dict): @@ -218,164 +290,124 @@ def set_gradient_scale( model: nn.Module, gradient_scale: float = 1.0, ) -> None: - """Sets the gradient scale for all `TrackedModule` instances within a model.""" + """Sets the gradient scale for all `TrackedModule` instances within a model. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + gradient_scale (float): + The gradient scale to set. + """ for module in model.modules(): if isinstance(module, TrackedModule): module.set_gradient_scale(scale=gradient_scale) def prepare_modules(model: nn.Module, tracked_module_names: List[str], device: torch.device) -> None: + """Prepares specified `TrackedModule` instances for score computation. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to prepare. + device (torch.device): + The device to prepare the modules for. + """ for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: module.prepare_storage(device=device) def synchronize_modules(model: nn.Module, tracked_module_names: List[str], num_processes: int = 1) -> None: + """Synchronizes specified `TrackedModule` instances across processes. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to synchronize. + num_processes (int): + The number of processes to synchronize across. + """ for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: module.synchronize(num_processes=num_processes) def truncate(model: nn.Module, tracked_module_names: List[str], keep_size: int) -> None: + """Truncates the data in specified `TrackedModule` instances. + + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to truncate. + keep_size (int): + The number of elements to keep after truncation. 
+ """ for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: module.truncate(keep_size=keep_size) def exist_for_all_modules(model: nn.Module, tracked_module_names: List[str]) -> bool: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if not module.exist(): - return False - return True - - -def accumulate_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.accumulate_iterations() - + """Checks if all specified `TrackedModule` instances have existing factor. -def finalize_iteration(model: nn.Module, tracked_module_names: List[str]) -> None: - """Updates Lambda matrices for all modules listed in `tracked_module_names`.""" - for name, module in model.named_modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_iteration() - - -def finalize_all_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: - """Updates Lambda matrices for all modules listed in `tracked_module_names`.""" - for name, module in model.named_modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_all_iterations() - - -def finalize_preconditioned_gradient(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes preconditioned gradient for all modules listed in `tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_preconditioned_gradient() - - -def accumulate_preconditioned_gradient(model: nn.Module, tracked_module_names: List[str]) -> None: - """Accumulates preconditioned gradient for all modules listed in `tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.accumulate_preconditioned_gradient() - - -def release_preconditioned_gradient(model: nn.Module) -> None: - """Releases preconditioned gradient of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.release_preconditioned_gradient() - - -def truncate_preconditioned_gradient(model: nn.Module, tracked_module_names: List[str], keep_size: int) -> None: - """Truncates preconditioned gradient for all modules listed in `tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.truncate_preconditioned_gradient(keep_size=keep_size) - - -def synchronize_preconditioned_gradient(model: nn.Module, tracked_module_names: List[str], num_processes: int) -> None: - """Synchronizes preconditioned gradient for all modules listed in `tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.synchronize_preconditioned_gradient(num_processes=num_processes) - - -def release_scores(model: nn.Module) -> None: - """Releases scores of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.release_scores() - - -def finalize_pairwise_scores(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes pairwise influence scores for all modules listed in 
`tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_pairwise_score() + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to check. + Returns: + bool: + `True` if all specified modules have existing factor, `False` otherwise. + """ + return all( + module.exist() + for module in model.modules() + if isinstance(module, TrackedModule) and module.name in tracked_module_names + ) -def finalize_self_scores(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes self-influence scores for all modules listed in `tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_self_score() +def accumulate_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: + """Accumulates iterations for specified `TrackedModule` instances. -def finalize_self_measurement_scores(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes self-influence scores with measurement for all modules listed in `tracked_module_names`.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to accumulate iterations for. + """ for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_self_measurement_score() + module.accumulate_iterations() -def finalize_gradient_aggregation(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes aggregated gradient for all modules listed in `tracked_module_names`.""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_gradient_aggregation() - +def finalize_iteration(model: nn.Module, tracked_module_names: List[str]) -> None: + """Finalizes the current iteration for specified `TrackedModule` instances. -def synchronize_aggregated_gradient(model: nn.Module, tracked_module_names: List[str]) -> None: - """Synchronizes aggregated gradient for all modules listed in `tracked_module_names`.""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to finalize iteration for. 
+ """ for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.synchronize_aggregated_gradient() - - -def release_aggregated_gradient(model: nn.Module) -> None: - """Releases aggregated gradient of all `TrackedModule` instances within a model.""" - for module in model.modules(): - if isinstance(module, TrackedModule): - module.release_aggregated_gradient() - - -def aggregated_gradient_exist(model: nn.Module, tracked_module_names: List[str]) -> bool: - """Checks if the aggregated gradient is computed for all modules listed in `tracked_module_names`.""" - exists = True - for name, module in model.named_modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - and module.aggregated_gradient is None - ): - exists = False - return exists - + module.finalize_iteration() -def compute_preconditioned_gradient_from_aggregation(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes preconditioned aggregated gradient for all modules listed in `tracked_module_names`""" - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.compute_preconditioned_gradient_from_aggregation() +def finalize_all_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: + """Finalizes all iterations for specified `TrackedModule` instances. -def compute_pairwise_scores_from_aggregation(model: nn.Module, tracked_module_names: List[str]) -> None: - """Computes preconditioned aggregated gradient for all modules listed in `tracked_module_names`""" + Args: + model (nn.Module): + The PyTorch model containing `TrackedModule` instances. + tracked_module_names (List[str]): + Names of modules to finalize all iterations for. 
+ """ for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.compute_pairwise_scores_from_aggregation() + module.finalize_all_iterations() diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py index d8101a5..8467e00 100644 --- a/kronfluence/score/dot_product.py +++ b/kronfluence/score/dot_product.py @@ -95,31 +95,32 @@ def compute_dot_products_with_loader( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) - if score_args.compute_per_module_scores: - for module in cached_module_lst: - score_chunks[module.name].append( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() - ) - else: - pairwise_scores = None - for module in cached_module_lst: - if pairwise_scores is None: - pairwise_scores = torch.zeros_like( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + with torch.no_grad(): + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to(device="cpu", copy=True) ) - try: - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) - except RuntimeError: - if score_args.compute_per_token_scores: - raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) - raise - score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) - accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + else: + pairwise_scores = None + for module in cached_module_lst: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError as exc: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) from exc + raise + score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: state.wait_for_everyone() - del batch, loss + del loss total_steps += 1 pbar.update(1) @@ -141,9 +142,11 @@ def compute_dot_products_with_loader( gather_list = None if state.is_main_process: gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) + dist.gather(total_scores[module_name], gather_list) if state.is_main_process: total_scores[module_name] = torch.cat(gather_list, dim=1)[:, :dataset_size].cpu() + else: + total_scores[module_name] = total_scores[module_name].cpu() state.wait_for_everyone() return total_scores @@ -203,7 +206,7 @@ def compute_aggregated_dot_products_with_loader( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) - del batch, loss + del loss pbar.update(1) if state.use_distributed: @@ -217,26 +220,29 @@ def compute_aggregated_dot_products_with_loader( ) finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - scores[module.name] = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() - else: - pairwise_scores = None - for 
module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if pairwise_scores is None: - pairwise_scores = torch.zeros_like( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + with torch.no_grad(): + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + scores[module.name] = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to( + device="cpu", copy=True ) - try: - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) - except RuntimeError: - if score_args.compute_per_token_scores: - raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) - raise - scores[ALL_MODULE_NAME] = pairwise_scores.cpu() - accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + else: + pairwise_scores = None + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError as exc: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) from exc + raise + scores[ALL_MODULE_NAME] = pairwise_scores.cpu() + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) model.zero_grad(set_to_none=True) set_mode( diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index 2a993cb..abc4a1f 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ -32,7 +32,7 @@ from kronfluence.task import Task from kronfluence.utils.constants import FACTOR_TYPE, PARTITION_TYPE, SCORE_TYPE from kronfluence.utils.logger import TQDM_BAR_FORMAT -from kronfluence.utils.state import State, no_sync, release_memory +from kronfluence.utils.state import State, no_sync def pairwise_scores_save_path( @@ -70,7 +70,7 @@ def save_pairwise_scores( Args: output_dir (Path): Directory to save the scores. - scores (FACTOR_TYPE): + scores (SCORE_TYPE): Dictionary of scores to save. partition (PARTITION_TYPE, optional): Partition information, if any. @@ -87,7 +87,7 @@ def save_pairwise_scores( def load_pairwise_scores( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, -) -> Dict[str, torch.Tensor]: +) -> SCORE_TYPE: """Loads pairwise scores from disk. Args: @@ -171,7 +171,7 @@ def compute_pairwise_scores_with_loaders( Whether to disable the progress bar. Defaults to `False`. Returns: - Dict[str, torch.Tensor]: + SCORE_TYPE: A dictionary containing the module name and its pairwise influence scores. 
""" update_factor_args(model=model, factor_args=factor_args) @@ -190,6 +190,7 @@ def compute_pairwise_scores_with_loaders( model=model, factor_name=name, factors=loaded_factors[name], + clone=True, ) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) @@ -271,6 +272,7 @@ def compute_pairwise_scores_with_loaders( if module_name not in total_scores_chunks: total_scores_chunks[module_name] = [] total_scores_chunks[module_name].append(current_scores) + del scores state.wait_for_everyone() num_accumulations = 0 @@ -318,7 +320,7 @@ def compute_pairwise_query_aggregated_scores_with_loaders( ) if len(loaded_factors) > 0: for name in loaded_factors: - set_factors(model=model, factor_name=name, factors=loaded_factors[name]) + set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) enable_amp = score_args.amp_dtype is not None @@ -354,7 +356,7 @@ def compute_pairwise_query_aggregated_scores_with_loaders( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) - del query_batch, measurement + del measurement pbar.update(1) if state.use_distributed: diff --git a/kronfluence/score/self.py b/kronfluence/score/self.py index e501b17..a1466ce 100644 --- a/kronfluence/score/self.py +++ b/kronfluence/score/self.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional import torch +import torch.distributed as dist from accelerate.utils import send_to_device from safetensors.torch import load_file, save_file from torch import autocast, nn @@ -15,12 +16,8 @@ from kronfluence.module.utils import ( accumulate_iterations, finalize_iteration, - finalize_preconditioned_gradient, - finalize_self_measurement_scores, - finalize_self_scores, get_tracked_module_names, prepare_modules, - release_scores, set_factors, set_gradient_scale, set_mode, @@ -75,7 +72,7 @@ def save_self_scores( Args: output_dir (Path): Directory to save the scores. - scores (FACTOR_TYPE): + scores (SCORE_TYPE): Dictionary of scores to save. partition (PARTITION_TYPE, optional): Partition information, if any. @@ -92,7 +89,7 @@ def save_self_scores( def load_self_scores( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, -) -> Dict[str, torch.Tensor]: +) -> SCORE_TYPE: """Loads self-influence scores from disk. Args: @@ -170,7 +167,7 @@ def compute_self_scores_with_loaders( Whether to disable the progress bar. Defaults to `False`. Returns: - Dict[str, torch.Tensor]: + SCORE_TYPE: A dictionary containing the module name and its self-influence scores. 
""" update_factor_args(model=model, factor_args=factor_args) @@ -185,7 +182,7 @@ def compute_self_scores_with_loaders( ) if len(loaded_factors) > 0: for name in loaded_factors: - set_factors(model=model, factor_name=name, factors=loaded_factors[name]) + set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) dataset_size = len(train_loader.dataset) @@ -215,7 +212,7 @@ def compute_self_scores_with_loaders( bar_format=TQDM_BAR_FORMAT, disable=not state.is_main_process or disable_tqdm, ) as pbar: - for batch in train_loader: + for index, batch in enumerate(train_loader): batch = send_to_device( tensor=batch, device=state.device, @@ -234,26 +231,31 @@ def compute_self_scores_with_loaders( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) - if score_args.compute_per_module_scores: - for module in cached_module_lst: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() - ) - else: - self_scores = None - for module in cached_module_lst: - if self_scores is None: - self_scores = torch.zeros_like( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + with torch.no_grad(): + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).to(device="cpu", copy=True) ) - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) - score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) - accumulate_iterations(model=model, tracked_module_names=tracked_module_names) - - if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: + else: + self_scores = None + for module in cached_module_lst: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + ) + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if ( + state.use_distributed + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 + and index not in [len(train_loader) - 1, len(train_loader) - 2] + ): state.wait_for_everyone() - del batch, loss + del loss total_steps += 1 pbar.update(1) @@ -276,9 +278,11 @@ def compute_self_scores_with_loaders( gather_list = None if state.is_main_process: gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) + dist.gather(total_scores[module_name], gather_list) if state.is_main_process: total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() + else: + total_scores[module_name] = total_scores[module_name].cpu() state.wait_for_everyone() return total_scores @@ -307,23 +311,23 @@ def compute_self_measurement_scores_with_loaders( model=model, factor_name=name, factors=loaded_factors[name], + clone=True, ) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + dataset_size = len(train_loader.dataset) score_chunks: Dict[str, List[torch.Tensor]] = {} if 
score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name] = [] + for module in cached_module_lst: + score_chunks[module.name] = [] else: score_chunks[ALL_MODULE_NAME] = [] - cached_module_lst = [] - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - cached_module_lst.append(module) - total_steps = 0 enable_amp = score_args.amp_dtype is not None scaler = GradScaler(enabled=enable_amp) @@ -337,7 +341,7 @@ def compute_self_measurement_scores_with_loaders( bar_format=TQDM_BAR_FORMAT, disable=not state.is_main_process or disable_tqdm, ) as pbar: - for batch in train_loader: + for index, batch in enumerate(train_loader): batch = send_to_device( tensor=batch, device=state.device, @@ -377,25 +381,30 @@ def compute_self_measurement_scores_with_loaders( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) - del batch, loss + del loss - if score_args.compute_per_module_scores: - for module in cached_module_lst: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() - ) - else: - self_scores = None - for module in cached_module_lst: - if self_scores is None: - self_scores = torch.zeros_like( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + with torch.no_grad(): + if score_args.compute_per_module_scores: + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).to(device="cpu", copy=True) ) - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) - score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) - accumulate_iterations(model=model, tracked_module_names=tracked_module_names) - - if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: + else: + self_scores = None + for module in cached_module_lst: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + ) + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if ( + state.use_distributed + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 + and index not in [len(train_loader) - 1, len(train_loader) - 2] + ): state.wait_for_everyone() total_steps += 1 @@ -420,9 +429,11 @@ def compute_self_measurement_scores_with_loaders( gather_list = None if state.is_main_process: gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) + dist.gather(total_scores[module_name], gather_list) if state.is_main_process: total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() + else: + total_scores[module_name] = total_scores[module_name].cpu() state.wait_for_everyone() return total_scores diff --git a/kronfluence/task.py b/kronfluence/task.py index 33472a3..1c82dd5 100644 --- a/kronfluence/task.py +++ b/kronfluence/task.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional, Union, Dict +from typing import Any, Dict, List, Optional, Union import torch from torch import nn diff --git a/tests/factors/test_lambdas.py 
b/tests/factors/test_lambdas.py index 751193c..d2f15e8 100644 --- a/tests/factors/test_lambdas.py +++ b/tests/factors/test_lambdas.py @@ -321,65 +321,6 @@ def test_lambda_matrices_max_examples( assert num_examples == max_examples -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "conv_bn", - ], -) -@pytest.mark.parametrize("module_partitions", [1, 2]) -@pytest.mark.parametrize("train_size", [100]) -@pytest.mark.parametrize("seed", [5]) -def test_lambda_matrices_amp( - test_name: str, - module_partitions: int, - train_size: int, - seed: int, -) -> None: - # Lambda matrices should be similar even when AMP is enabled. - model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = pytest_factor_arguments() - factor_args.lambda_module_partitions = module_partitions - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=8, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - factor_args.amp_dtype = torch.float16 - analyzer.fit_all_factors( - factors_name=custom_factors_name("amp"), - dataset=train_dataset, - per_device_batch_size=8, - overwrite_output_dir=True, - factor_args=factor_args, - dataloader_kwargs=kwargs, - ) - amp_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name("amp"), - ) - - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence(lambda_factors[name], amp_lambda_factors[name], atol=1e-01, rtol=1e-02) - - @pytest.mark.parametrize( "test_name", [ diff --git a/tests/gpu_tests/README.md b/tests/gpu_tests/README.md index dee6108..8f41e5e 100644 --- a/tests/gpu_tests/README.md +++ b/tests/gpu_tests/README.md @@ -49,7 +49,7 @@ python amp_test.py ### CPU Offload Test -To test if `cached_activation_cpu_offload` option is properly implemented, run: +To test if `offload_activations_to_cpu` option is properly implemented, run: ```bash pytest test_offload_cpu.py diff --git a/tests/gpu_tests/compile_test.py b/tests/gpu_tests/compile_test.py index ca9e008..076ad28 100644 --- a/tests/gpu_tests/compile_test.py +++ b/tests/gpu_tests/compile_test.py @@ -16,7 +16,7 @@ ) from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence, ATOL, RTOL +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence logging.basicConfig(level=logging.DEBUG) OLD_FACTOR_NAME = "single_gpu" diff --git a/tests/gpu_tests/cpu_test.py b/tests/gpu_tests/cpu_test.py index 93dde7a..31505af 100644 --- a/tests/gpu_tests/cpu_test.py +++ b/tests/gpu_tests/cpu_test.py @@ -16,7 +16,7 @@ ) from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence, ATOL, RTOL +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence logging.basicConfig(level=logging.DEBUG) OLD_FACTOR_NAME = "single_gpu" diff --git a/tests/gpu_tests/ddp_test.py b/tests/gpu_tests/ddp_test.py index 2cdc2e6..2ff0499 100644 --- a/tests/gpu_tests/ddp_test.py +++ b/tests/gpu_tests/ddp_test.py @@ -19,7 
+19,7 @@ from kronfluence.utils.model import apply_ddp from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence, ATOL, RTOL +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence LOCAL_RANK = int(os.environ["LOCAL_RANK"]) WORLD_RANK = int(os.environ["RANK"]) diff --git a/tests/gpu_tests/fsdp_test.py b/tests/gpu_tests/fsdp_test.py index 6be6579..3335709 100644 --- a/tests/gpu_tests/fsdp_test.py +++ b/tests/gpu_tests/fsdp_test.py @@ -19,7 +19,7 @@ from kronfluence.utils.model import apply_fsdp from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence, ATOL, RTOL +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence LOCAL_RANK = int(os.environ["LOCAL_RANK"]) WORLD_RANK = int(os.environ["RANK"]) diff --git a/tests/gpu_tests/pipeline.py b/tests/gpu_tests/pipeline.py index 117ec52..76f542d 100644 --- a/tests/gpu_tests/pipeline.py +++ b/tests/gpu_tests/pipeline.py @@ -25,12 +25,12 @@ def compute_train_loss( if not sample: return F.cross_entropy(logits, labels, reduction="sum") with torch.no_grad(): - probs = torch.nn.functional.softmax(logits, dim=-1) + probs = torch.nn.functional.softmax(logits.detach(), dim=-1) sampled_labels = torch.multinomial( probs, num_samples=1, ).flatten() - return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum") + return F.cross_entropy(logits, sampled_labels, reduction="sum") def compute_measurement( self, diff --git a/tests/gpu_tests/prepare_tests.py b/tests/gpu_tests/prepare_tests.py index 6c71b41..a9a140b 100644 --- a/tests/gpu_tests/prepare_tests.py +++ b/tests/gpu_tests/prepare_tests.py @@ -6,8 +6,8 @@ from tqdm import tqdm from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments, ScoreArguments from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset # Pick difficult cases where the dataset is not perfectly divisible by batch size. 
@@ -102,47 +102,89 @@ def run_analysis() -> None: overwrite_output_dir=True, ) - # score_args = ScoreArguments( - # score_dtype=torch.float64, - # per_sample_gradient_dtype=torch.float64, - # precondition_dtype=torch.float64, - # ) - # analyzer.compute_pairwise_scores( - # scores_name="single_gpu", - # factors_name="single_gpu", - # query_dataset=eval_dataset, - # train_dataset=train_dataset, - # per_device_query_batch_size=12, - # per_device_train_batch_size=512, - # score_args=score_args, - # overwrite_output_dir=True, - # ) - # analyzer.compute_self_scores( - # scores_name="single_gpu", - # factors_name="single_gpu", - # train_dataset=train_dataset, - # per_device_train_batch_size=512, - # score_args=score_args, - # overwrite_output_dir=True, - # ) - - # score_args = ScoreArguments( - # query_gradient_rank=32, - # score_dtype=torch.float64, - # per_sample_gradient_dtype=torch.float64, - # precondition_dtype=torch.float64, - # query_gradient_svd_dtype=torch.float64, - # ) - # analyzer.compute_pairwise_scores( - # scores_name="single_gpu_qb", - # factors_name="single_gpu", - # query_dataset=eval_dataset, - # train_dataset=train_dataset, - # per_device_query_batch_size=12, - # per_device_train_batch_size=512, - # score_args=score_args, - # overwrite_output_dir=True, - # ) + score_args = pytest_score_arguments() + analyzer.compute_pairwise_scores( + scores_name="single_gpu", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + analyzer.compute_self_scores( + scores_name="single_gpu", + factors_name="single_gpu", + train_dataset=train_dataset, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.use_measurement_for_self_influence = True + analyzer.compute_self_scores( + scores_name="single_gpu_measurement", + factors_name="single_gpu", + train_dataset=train_dataset, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.query_gradient_low_rank = 32 + analyzer.compute_pairwise_scores( + scores_name="single_gpu_qb", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.aggregate_train_gradients = True + analyzer.compute_pairwise_scores( + scores_name="single_gpu_train_agg", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.aggregate_query_gradients = True + analyzer.compute_pairwise_scores( + scores_name="single_gpu_query_agg", + factors_name="single_gpu", + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + + score_args = pytest_score_arguments() + score_args.aggregate_train_gradients = True + score_args.aggregate_query_gradients = True + analyzer.compute_pairwise_scores( + scores_name="single_gpu_all_agg", + factors_name="single_gpu", + query_dataset=eval_dataset, + 
train_dataset=train_dataset, + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) if __name__ == "__main__": diff --git a/tests/gpu_tests/test_offload_cpu.py b/tests/gpu_tests/test_offload_cpu.py index 0066270..fca8ef7 100644 --- a/tests/gpu_tests/test_offload_cpu.py +++ b/tests/gpu_tests/test_offload_cpu.py @@ -6,6 +6,8 @@ from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.common.score_arguments import pytest_score_arguments from kronfluence.utils.constants import ALL_MODULE_NAME from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test @@ -23,13 +25,13 @@ "gpt", ], ) -@pytest.mark.parametrize("cached_activation_cpu_offload", [True, False]) +@pytest.mark.parametrize("offload_activations_to_cpu", [True, False]) @pytest.mark.parametrize("query_size", [16]) @pytest.mark.parametrize("train_size", [32]) @pytest.mark.parametrize("seed", [1]) def test_cpu_offloads( test_name: str, - cached_activation_cpu_offload: bool, + offload_activations_to_cpu: bool, query_size: int, train_size: int, seed: int, @@ -50,10 +52,10 @@ def test_cpu_offloads( disable_tqdm=True, ) factor_args = FactorArguments( - cached_activation_cpu_offload=cached_activation_cpu_offload, + offload_activations_to_cpu=offload_activations_to_cpu, ) if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True factors_name = f"pytest_{test_name}_{test_cpu_offloads.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -65,7 +67,7 @@ def test_cpu_offloads( ) score_args = ScoreArguments( - cached_activation_cpu_offload=cached_activation_cpu_offload, + offload_activations_to_cpu=offload_activations_to_cpu, ) scores_name = f"pytest_{test_name}_{test_cpu_offloads.__name__}_scores" analyzer.compute_pairwise_scores( @@ -122,15 +124,9 @@ def test_cpu_offloads_identical( disable_model_save=True, disable_tqdm=True, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - cached_activation_cpu_offload=False, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True + factor_args.has_shared_parameters = True factors_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -140,13 +136,7 @@ def test_cpu_offloads_identical( factor_args=factor_args, overwrite_output_dir=True, ) - score_args = ScoreArguments( - cached_activation_cpu_offload=False, - per_sample_gradient_dtype=torch.float64, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_module_score=per_module_score, - ) + score_args = pytest_score_arguments() scores_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_scores" analyzer.compute_pairwise_scores( scores_name=scores_name, @@ -162,7 +152,7 @@ def test_cpu_offloads_identical( pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) factors_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_cached" - factor_args.cached_activation_cpu_offload = True + factor_args.offload_activations_to_cpu = True analyzer.fit_all_factors( 
factors_name=factors_name, dataset=train_dataset, @@ -171,7 +161,7 @@ def test_cpu_offloads_identical( factor_args=factor_args, overwrite_output_dir=True, ) - score_args.cached_activation_cpu_offload = True + score_args.offload_activations_to_cpu = True scores_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_cached_scores" analyzer.compute_pairwise_scores( scores_name=scores_name, diff --git a/tests/modules/test_matmul.py b/tests/modules/test_matmul.py index 42b38a4..860c4b3 100644 --- a/tests/modules/test_matmul.py +++ b/tests/modules/test_matmul.py @@ -3,6 +3,7 @@ import torch from accelerate.utils import set_seed from opt_einsum import DynamicProgramming +import time def test_query_gradient_svd( @@ -180,3 +181,56 @@ def test_compute_score_matmul( optimize=DynamicProgramming(search_outer=True, minimize="flops"), ) print(path) + + +def test_precondition_gradient( + seed: int = 0, +) -> None: + input_dim = 128 + output_dim = 256 + batch_dim = 8 + lambda_scale = 1000 + damping = 1e-08 + + set_seed(seed) + A = torch.rand(size=(input_dim, input_dim), dtype=torch.float64) + B = torch.rand(size=(output_dim, output_dim), dtype=torch.float64) + Lambda = torch.rand(size=(output_dim, input_dim), dtype=torch.float64) + gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) + + start_time = time.time() + rotated_gradient = torch.einsum( + "ij,bjl,lk->bik", + ( + B.t(), + gradient, + A, + ), + ) + rotated_gradient.div_(Lambda + damping) + results = lambda_scale * torch.einsum( + "ij,bjl,lk->bik", + (B, rotated_gradient, A.t()), + ) + print(f"Took {time.time() - start_time} seconds.") + + start_time = time.time() + grads_rot = torch.matmul( + B.t(), + torch.matmul( + gradient, + A, + ), + ) + scaled_lambda = Lambda / lambda_scale + grads_rot.div_(scaled_lambda) + raw_results = torch.matmul( + B, + torch.matmul( + grads_rot, + A.t(), + ), + ) + print(f"Took {time.time() - start_time} seconds.") + + assert torch.allclose(raw_results, results, atol=1e-5, rtol=1e-3) diff --git a/tests/modules/test_modules.py b/tests/modules/test_modules.py index b6863cd..7e60489 100644 --- a/tests/modules/test_modules.py +++ b/tests/modules/test_modules.py @@ -23,6 +23,7 @@ ModuleMode.COVARIANCE, ModuleMode.LAMBDA, ModuleMode.PRECONDITION_GRADIENT, + ModuleMode.GRADIENT_AGGREGATION, ], ) @pytest.mark.parametrize("train_size", [32]) @@ -82,6 +83,7 @@ def test_tracked_modules_forward_equivalence( ModuleMode.COVARIANCE, ModuleMode.LAMBDA, ModuleMode.PRECONDITION_GRADIENT, + ModuleMode.GRADIENT_AGGREGATION, ], ) @pytest.mark.parametrize("train_size", [32]) @@ -126,7 +128,8 @@ def test_tracked_modules_backward_equivalence( wrapped_loss = task.compute_train_loss(batch, wrapped_model, sample=False) wrapped_loss.backward() for name, param in wrapped_model.named_parameters(): - wrapped_grads[name] = param.grad.detach() + if param.grad is not None: + wrapped_grads[name] = param.grad.detach() for name, grad in wrapped_grads.items(): original_name = name.replace(".original_module", "") diff --git a/tests/modules/test_per_sample_gradients.py b/tests/modules/test_per_sample_gradients.py index f2ae5e5..84a9481 100644 --- a/tests/modules/test_per_sample_gradients.py +++ b/tests/modules/test_per_sample_gradients.py @@ -393,56 +393,3 @@ def test_lambda_equivalence( atol=ATOL, rtol=RTOL, ) - - -def test_precondition_gradient( - seed: int = 0, -) -> None: - input_dim = 128 - output_dim = 256 - batch_dim = 8 - lambda_scale = 1000 - damping = 1e-08 - - set_seed(seed) - A = 
torch.rand(size=(input_dim, input_dim), dtype=torch.float64) - B = torch.rand(size=(output_dim, output_dim), dtype=torch.float64) - Lambda = torch.rand(size=(output_dim, input_dim), dtype=torch.float64) - gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) - - start_time = time.time() - rotated_gradient = torch.einsum( - "ij,bjl,lk->bik", - ( - B.t(), - gradient, - A, - ), - ) - rotated_gradient.div_(Lambda + damping) - results = lambda_scale * torch.einsum( - "ij,bjl,lk->bik", - (B, rotated_gradient, A.t()), - ) - print(f"Took {time.time() - start_time} seconds.") - - start_time = time.time() - grads_rot = torch.matmul( - B.t(), - torch.matmul( - gradient, - A, - ), - ) - scaled_lambda = Lambda / lambda_scale - grads_rot.div_(scaled_lambda) - raw_results = torch.matmul( - B, - torch.matmul( - grads_rot, - A.t(), - ), - ) - print(f"Took {time.time() - start_time} seconds.") - - assert torch.allclose(raw_results, results, atol=1e-5, rtol=1e-3) diff --git a/tests/scores/test_pairwise_scores.py b/tests/scores/test_pairwise_scores.py index 85f0709..c3d412d 100644 --- a/tests/scores/test_pairwise_scores.py +++ b/tests/scores/test_pairwise_scores.py @@ -27,12 +27,12 @@ "test_name", [ "mlp", - # "repeated_mlp", - # "conv", - # "bert", - # "roberta", - # "gpt", - # "gpt_checkpoint", + "repeated_mlp", + "conv", + "bert", + "roberta", + "gpt", + "gpt_checkpoint", ], ) @pytest.mark.parametrize("score_dtype", [torch.float32]) @@ -96,7 +96,6 @@ def test_compute_pairwise_scores( @pytest.mark.parametrize("test_name", ["mlp"]) -@pytest.mark.parametrize("einsum_minimize_size", [True, False]) @pytest.mark.parametrize("has_shared_parameters", [True, False]) @pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.float16]) @@ -109,7 +108,6 @@ def test_compute_pairwise_scores( def test_compute_pairwise_scores_dtype( test_name: str, has_shared_parameters: bool, - einsum_minimize_size: bool, per_sample_gradient_dtype: torch.dtype, precondition_dtype: torch.dtype, score_dtype: torch.dtype, @@ -145,7 +143,6 @@ def test_compute_pairwise_scores_dtype( score_args = ScoreArguments( damping_factor=damping_factor, - einsum_minimize_size=einsum_minimize_size, score_dtype=score_dtype, query_gradient_low_rank=query_gradient_low_rank, per_sample_gradient_dtype=per_sample_gradient_dtype, @@ -655,9 +652,7 @@ def test_query_accumulation_steps( @pytest.mark.parametrize( "test_name", [ - "mlp", - # "repeated_mlp", - # "roberta", + "mlp", "conv" ], ) @pytest.mark.parametrize("query_size", [50]) @@ -745,9 +740,7 @@ def test_query_gradient_aggregation( @pytest.mark.parametrize( "test_name", [ - "mlp", - # "conv_bn", - # "gpt", + "mlp", "conv" ], ) @pytest.mark.parametrize("query_size", [64]) diff --git a/tests/scores/test_self_scores.py b/tests/scores/test_self_scores.py index 619f147..c0ee190 100644 --- a/tests/scores/test_self_scores.py +++ b/tests/scores/test_self_scores.py @@ -163,8 +163,7 @@ def test_compute_self_scores_dtype( "conv_bn", ], ) -# @pytest.mark.parametrize("strategy", ["identity", "diagonal", "kfac", "ekfac"]) -@pytest.mark.parametrize("strategy", ["ekfac"]) +@pytest.mark.parametrize("strategy", ["identity", "diagonal", "kfac", "ekfac"]) @pytest.mark.parametrize("train_size", [49]) @pytest.mark.parametrize("seed", [2]) def test_self_scores_batch_size_equivalence( diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index e5293ea..885b6c9 100644 --- a/tests/test_analyzer.py +++ 
b/tests/test_analyzer.py @@ -129,7 +129,6 @@ def test_default_score_arguments() -> None: assert score_args.damping_factor == 1e-08 assert score_args.amp_dtype is None assert score_args.offload_activations_to_cpu is False - assert score_args.einsum_minimize_size is False assert score_args.data_partitions == 1 assert score_args.module_partitions == 1 @@ -146,6 +145,6 @@ def test_default_score_arguments() -> None: assert score_args.use_measurement_for_self_influence is False assert score_args.query_gradient_svd_dtype == torch.float32 - assert score_args.score_dtype == torch.float32 assert score_args.per_sample_gradient_dtype == torch.float32 assert score_args.precondition_dtype == torch.float32 + assert score_args.score_dtype == torch.float32 diff --git a/tests/utils.py b/tests/utils.py index d49a470..03ca290 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -207,6 +207,6 @@ def reshape_parameter_gradient_to_module_matrix( if remove_gradient: del gradient_dict[module_name + ".bias"] else: - error_msg = f"Unsupported module type: {type(module)}. Only nn.Linear or nn.Conv2d are supported." + error_msg = f"Unsupported module type: {type(module)}. Only `nn.Linear` or `nn.Conv2d` are supported." raise UnsupportableModuleError(error_msg) return gradient_matrix
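
A note on the recurring `.clone().cpu()` → `.to(device="cpu", copy=True)` change in the score-accumulation paths above: the following is a minimal, standalone sketch (illustrative only, not part of the patch; the tensor name is a stand-in for a module's cached score factor). It shows that the two forms yield identical values, that `copy=True` still returns a fresh tensor even when the source already lives on the CPU, and that performing the copies under `torch.no_grad()` keeps them out of the autograd graph, consistent with the accumulation loops now being wrapped in `torch.no_grad()`.

```python
import torch

# Illustrative stand-in for a module's cached score factor.
scores = torch.randn(4, 8, requires_grad=True)

with torch.no_grad():
    # Previous pattern: clone on the source device, then move to CPU.
    old_style = scores.clone().cpu()
    # New pattern: a single call that moves to CPU and always returns a copy,
    # even when the tensor is already resident on the CPU.
    new_style = scores.to(device="cpu", copy=True)

# Both forms hold the same values, and neither tracks gradients because the
# copies were made under `torch.no_grad()`.
assert torch.equal(old_style, new_style)
assert not old_style.requires_grad and not new_style.requires_grad
```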