Finalize refactor
pomonam committed Jul 8, 2024
1 parent 398503c commit 133f53c
Showing 34 changed files with 783 additions and 689 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-test.yml
@@ -32,5 +32,6 @@ jobs:
pytest -vx tests/factors/test_lambdas.py
pytest -vx tests/modules/test_modules.py
pytest -vx tests/modules/test_per_sample_gradients.py
pytest -vx tests/modules/test_matmul.py
pytest -vx tests/scores/test_pairwise_scores.py
pytest -vx tests/scores/test_self_scores.py
15 changes: 7 additions & 8 deletions kronfluence/factor/config.py
@@ -201,7 +201,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
if damping_factor is None:
damping_factor = 0.1 * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu")
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()

def precondition_gradient(
self,
@@ -249,18 +249,18 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to(
dtype=score_args.precondition_dtype
)
).contiguous()
storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to(
dtype=score_args.precondition_dtype
)
).contiguous()
activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device)
gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device)
lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0)
damping_factor = score_args.damping_factor
if damping_factor is None:
damping_factor = 0.1 * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu")
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
storage[ACTIVATION_EIGENVALUES_NAME] = None
storage[GRADIENT_EIGENVALUES_NAME] = None

@@ -316,20 +316,19 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to(
dtype=score_args.precondition_dtype
)
).contiguous()
storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to(
dtype=score_args.precondition_dtype
)
).contiguous()
storage[ACTIVATION_EIGENVALUES_NAME] = None
storage[GRADIENT_EIGENVALUES_NAME] = None

lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device)
lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
damping_factor = score_args.damping_factor
if damping_factor is None:
damping_factor = 0.1 * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu")
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()

@torch.no_grad()
def precondition_gradient(
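The `prepare` implementations in this file build the Kronecker-factored lambda matrix, apply damping, and now call `.contiguous()` before caching the result on CPU. Below is a minimal sketch of that damping step with small placeholder eigenvalue tensors; it mirrors the lines shown in this diff but is illustrative rather than the library's actual code path.

```python
import torch

# Placeholder eigenvalues for a tiny layer (illustrative values, not real factors).
activation_eigenvalues = torch.tensor([0.5, 1.0, 2.0])  # input-side eigenvalues, shape (d_in,) = (3,)
gradient_eigenvalues = torch.tensor([0.1, 0.4])         # output-side eigenvalues, shape (d_out,) = (2,)

# Every pairwise product of input/output eigenvalues, as in the diff.
lambda_matrix = torch.kron(
    activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)
).unsqueeze(0)  # shape (1, d_out, d_in) = (1, 2, 3)

# Default damping used in the diff: 10% of the mean eigenvalue when none is given.
damping_factor = None
if damping_factor is None:
    damping_factor = 0.1 * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)

# .contiguous() guarantees a densely laid-out tensor before it is cached on CPU.
stored = lambda_matrix.to(dtype=torch.float64, device="cpu").contiguous()
print(stored.shape)  # torch.Size([1, 2, 3])
```

The `0.1 * mean` fallback keeps the preconditioner reasonably conditioned when no explicit `damping_factor` is supplied.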
2 changes: 1 addition & 1 deletion kronfluence/factor/covariance.py
@@ -232,7 +232,7 @@ def fit_covariance_matrices_with_loader(
state.wait_for_everyone()

num_data_processed.add_(find_batch_size(data=batch))
del batch, attention_mask, loss
del loss
total_steps += 1
pbar.update(1)

2 changes: 1 addition & 1 deletion kronfluence/factor/eigen.py
@@ -429,7 +429,7 @@ def fit_lambda_matrices_with_loader(
state.wait_for_everyone()

num_data_processed.add_(find_batch_size(data=batch))
del batch, loss
del loss
total_steps += 1
pbar.update(1)

6 changes: 2 additions & 4 deletions kronfluence/module/conv2d.py
@@ -159,7 +159,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradien
input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
summed_gradient = contract("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0)
return summed_gradient.view((1, *summed_gradient.size()))
return summed_gradient

def compute_per_sample_gradient(
self,
@@ -176,14 +176,12 @@ def compute_per_sample_gradient(
)
return per_sample_gradient

@torch.no_grad()
def compute_pairwise_score(
self, preconditioned_gradient, input_activation: torch.Tensor, output_gradient: torch.Tensor
self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")

if isinstance(preconditioned_gradient, list):
left_mat, right_mat = preconditioned_gradient
if self.einsum_expression is None:
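The `compute_summed_gradient` fix above removes a double-unsqueeze: the contraction is already followed by `unsqueeze_(dim=0)`, so wrapping the result in `view((1, *summed_gradient.size()))` added a second singleton dimension. A short shape check with `torch.einsum` (sizes are made up; the module itself uses `opt_einsum.contract` with the same subscripts):

```python
import torch

b, o, i1, i2, d_in = 4, 8, 5, 5, 12                  # hypothetical unfolded conv2d sizes
output_gradient = torch.randn(b, i1 * i2, o)          # after rearrange "b o i1 i2 -> b (i1 i2) o"
input_activation = torch.randn(b, i1 * i2, d_in)      # flattened input patches

# Same contraction as contract("bci,bco->io", output_gradient, input_activation):
summed_gradient = torch.einsum("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0)
print(summed_gradient.shape)                                      # torch.Size([1, 8, 12])

# The deleted line would have nested the batch dimension once more:
print(summed_gradient.view((1, *summed_gradient.size())).shape)  # torch.Size([1, 1, 8, 12])
```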
2 changes: 1 addition & 1 deletion kronfluence/module/linear.py
@@ -77,7 +77,7 @@ def compute_per_sample_gradient(
return per_sample_gradient

def compute_pairwise_score(
self, preconditioned_gradient, input_activation: torch.Tensor, output_gradient: torch.Tensor
self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
if isinstance(preconditioned_gradient, list):
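For intuition about what `compute_pairwise_score` produces for a linear layer, here is a rough, hypothetical sketch: the per-sample gradient of a linear layer is the outer product of the output gradient and the input activation, and the pairwise score contracts each query's preconditioned gradient against it. This ignores sequence dimensions and the low-rank (list) branch shown above, and is not the module's actual implementation.

```python
import torch

num_query, num_train, d_out, d_in = 3, 5, 4, 6

preconditioned_gradient = torch.randn(num_query, d_out, d_in)  # one preconditioned gradient per query
output_gradient = torch.randn(num_train, d_out)                # train-batch output gradients
input_activation = torch.randn(num_train, d_in)                # train-batch input activations

# Per-sample gradient of a linear layer is the outer product of output gradient and
# input activation; contracting each query's preconditioned gradient against it
# yields a (num_query, num_train) score matrix.
pairwise_scores = torch.einsum("qoi,to,ti->qt", preconditioned_gradient, output_gradient, input_activation)
print(pairwise_scores.shape)  # torch.Size([3, 5])
```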
112 changes: 96 additions & 16 deletions kronfluence/module/tracked_module.py
@@ -31,8 +31,11 @@


class ModuleMode(str, BaseEnum):
"""Enum representing a module's mode, indicating which factors and scores
need to be computed during forward and backward passes."""
"""Enum representing a module's mode for factor and score computation.
This enum indicates which factors and scores need to be computed during
forward and backward passes.
"""

DEFAULT = "default"
COVARIANCE = "covariance"
@@ -124,7 +124,7 @@ def __init__(
self._initialize_storage()

def _initialize_storage(self) -> None:
"""Initializes trackers for different module modes."""
"""Initializes storage for various factors and scores."""

# Storage for activation and pseudo-gradient covariance matrices #
for covariance_factor_name in COVARIANCE_FACTOR_NAMES:
@@ -138,56 +141,118 @@ def _initialize_storage(self) -> None:
for lambda_factor_name in LAMBDA_FACTOR_NAMES:
self.storage[lambda_factor_name]: Optional[torch.Tensor] = None

# Storage for preconditioned query gradients and influence scores #
# Storage for preconditioned gradients and influence scores #
self.storage[AGGREGATED_GRADIENT_NAME]: Optional[torch.Tensor] = None
self.storage[PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None
self.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None
self.storage[PAIRWISE_SCORE_MATRIX_NAME]: Optional[torch.Tensor] = None
self.storage[SELF_SCORE_VECTOR_NAME]: Optional[torch.Tensor] = None

def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
"""A forward pass of the tracked module. This should have identical behavior to that of the original module."""
def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Performs a forward pass of the tracked module.
This method should have identical behavior to that of the original module.
Args:
inputs (torch.Tensor):
Input tensor to the module.
*args:
Variable length argument list.
**kwargs:
Arbitrary keyword arguments.
Returns:
torch.Tensor:
The output of the forward pass.
"""
outputs = self.original_module(inputs, *args, **kwargs)
if outputs.requires_grad:
return outputs
return outputs + self._constant

def prepare_storage(self, device: torch.device) -> None:
"""Performs any necessary operations on storage before computing influence scores."""
"""Prepares storage for computing influence scores.
This method performs necessary operations on storage before computing influence scores.
Args:
device (torch.device):
The device to prepare the storage for.
"""
FactorConfig.CONFIGS[self.factor_args.strategy].prepare(
storage=self.storage,
score_args=self.score_args,
device=device,
)

def update_factor_args(self, factor_args: FactorArguments) -> None:
"""Updates the factor arguments."""
"""Updates the factor arguments.
Args:
factor_args (FactorArguments):
New factor arguments to set.
"""
self.factor_args = factor_args

def update_score_args(self, score_args: ScoreArguments) -> None:
"""Updates the score arguments."""
"""Updates the score arguments.
Args:
score_args (ScoreArguments):
New score arguments to set.
"""
self.score_args = score_args

def get_factor(self, factor_name: str) -> Optional[torch.Tensor]:
"""Returns the factor with the given name."""
"""Retrieves a factor by name from storage.
Args:
factor_name (str):
The name of the factor to retrieve.
Returns:
Optional[torch.Tensor]:
The requested factor, or `None` if not found.
"""
if factor_name not in self.storage or self.storage[factor_name] is None:
return None
return self.storage[factor_name]

def release_factor(self, factor_name: str) -> None:
"""Release the factor with the given name from memory."""
"""Releases a factor from memory.
Args:
factor_name (str):
The name of the factor to release.
"""
if factor_name not in self.storage or self.storage[factor_name] is None:
return None
del self.storage[factor_name]
self.storage[factor_name] = None

def set_factor(self, factor_name: str, factor: Any) -> None:
"""Sets the factor with the given name."""
"""Sets a factor in storage.
Args:
factor_name (str):
The name of the factor to set.
factor (Any):
The factor value to store.
"""
if factor_name in self.storage:
self.storage[factor_name] = factor

def set_mode(self, mode: ModuleMode, release_memory: bool = False) -> None:
"""Sets the module mode of the current `TrackedModule` instance."""
"""Sets the operating mode of the `TrackedModule`.
This method changes the current mode and manages associated trackers and memory.
Args:
mode (ModuleMode):
The new mode to set.
release_memory (bool):
Whether to release memory for all trackers.
"""
self._trackers[self.current_mode].release_hooks()
self.einsum_expression = None
self.current_mode = mode
@@ -199,19 +264,34 @@ def set_mode(self, mode: ModuleMode, release_memory: bool = False) -> None:
self._trackers[self.current_mode].register_hooks()

def set_attention_mask(self, attention_mask: Optional[torch.Tensor] = None) -> None:
"""Sets the attention mask for activation covariance computations."""
"""Sets the attention mask for activation covariance computations.
Args:
attention_mask (torch.Tensor, optional):
The attention mask to set.
"""
self.attention_mask = attention_mask

def set_gradient_scale(self, scale: float = 1.0) -> None:
"""Sets the scale of the gradient obtained from `GradScaler`."""
"""Sets the scale of the gradient obtained from `GradScaler`.
Args:
scale (float):
The scale factor to set.
"""
self.gradient_scale = scale

def finalize_iteration(self) -> None:
"""Finalizes statistics for the current iteration."""
self._trackers[self.current_mode].finalize_iteration()

def exist(self) -> bool:
"""Checks if the desired statistics are available."""
"""Checks if the desired statistics are available.
Returns:
bool:
`True` if statistics exist, `False` otherwise.
"""
return self._trackers[self.current_mode].exist()

def synchronize(self, num_processes: int) -> None:
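The expanded docstrings above spell out the factor-storage contract of `TrackedModule`: `set_factor` only writes to known names, `get_factor` returns `None` for missing entries, and `release_factor` keeps the key but drops the tensor. A small self-contained toy that mirrors those semantics (an illustration, not the library's class):

```python
from typing import Any, Dict, Optional

import torch


class ToyFactorStorage:
    """Toy stand-in that mirrors the get/set/release semantics documented above."""

    def __init__(self) -> None:
        self.storage: Dict[str, Optional[torch.Tensor]] = {"lambda_matrix": None}

    def set_factor(self, factor_name: str, factor: Any) -> None:
        # Only known factor names are written, matching the guarded assignment in the diff.
        if factor_name in self.storage:
            self.storage[factor_name] = factor

    def get_factor(self, factor_name: str) -> Optional[torch.Tensor]:
        if factor_name not in self.storage or self.storage[factor_name] is None:
            return None
        return self.storage[factor_name]

    def release_factor(self, factor_name: str) -> None:
        # The key is kept so later lookups return None instead of raising.
        if factor_name not in self.storage or self.storage[factor_name] is None:
            return
        del self.storage[factor_name]
        self.storage[factor_name] = None


storage = ToyFactorStorage()
storage.set_factor("lambda_matrix", torch.eye(3))
assert storage.get_factor("lambda_matrix") is not None
storage.release_factor("lambda_matrix")
assert storage.get_factor("lambda_matrix") is None
```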
7 changes: 6 additions & 1 deletion kronfluence/module/tracker/base.py
@@ -76,7 +76,12 @@ def finalize_iteration(self) -> None:
"""Finalizes statistics for the current iteration."""

def exist(self) -> bool:
"""Checks if the desired statistics are available."""
"""Checks if the desired statistics are available.
Returns:
bool:
`True` if statistics exist, `False` otherwise.
"""
return False

def synchronize(self, num_processes: int) -> None:
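`BaseTracker.exist()` defaults to `False`, and concrete trackers are expected to report whether their statistics have been accumulated. A hypothetical, minimal tracker that follows the same contract (the real trackers live elsewhere in `kronfluence/module/tracker/` and are not part of this diff):

```python
from typing import Optional

import torch


class ToyCovarianceTracker:
    """Hypothetical tracker following the exist() contract described above."""

    def __init__(self) -> None:
        self.covariance: Optional[torch.Tensor] = None

    def update(self, activations: torch.Tensor) -> None:
        # Accumulate an uncentered activation covariance across batches.
        batch_covariance = activations.t() @ activations
        self.covariance = batch_covariance if self.covariance is None else self.covariance + batch_covariance

    def exist(self) -> bool:
        # True once statistics have been accumulated, False otherwise.
        return self.covariance is not None


tracker = ToyCovarianceTracker()
print(tracker.exist())              # False
tracker.update(torch.randn(16, 8))
print(tracker.exist())              # True
```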
(The remaining changed files in this commit are not shown.)
