diff --git a/singd/structures/base.py b/singd/structures/base.py
index 6f622c3..0680b0d 100644
--- a/singd/structures/base.py
+++ b/singd/structures/base.py
@@ -9,6 +9,7 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor, zeros
+from torch.linalg import matrix_norm
 
 from singd.structures.utils import diag_add_, supported_eye
 
@@ -343,6 +344,15 @@ def infinity_vector_norm(self) -> Tensor:
         # NOTE `.max` can only be called on tensors with non-zero shape
         return max(t.abs().max() for _, t in self.named_tensors() if t.numel() > 0)
 
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        self._warn_naive_implementation("frobenius_norm")
+        return matrix_norm(self.to_dense())
+
     ###############################################################################
     #                      Special initialization operations                      #
     ###############################################################################
diff --git a/singd/structures/blockdiagonal.py b/singd/structures/blockdiagonal.py
index 2c81b6c..7fb8f1b 100644
--- a/singd/structures/blockdiagonal.py
+++ b/singd/structures/blockdiagonal.py
@@ -7,6 +7,7 @@
 import torch
 from einops import rearrange
 from torch import Tensor, arange, cat, einsum, zeros
+from torch.linalg import vector_norm
 
 from singd.structures.base import StructuredMatrix
 from singd.structures.utils import lowest_precision, supported_eye
@@ -313,6 +314,16 @@ def diag_add_(self, value: float) -> BlockDiagonalMatrixTemplate:
 
         return self
 
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        return vector_norm(
+            cat([t.flatten() for _, t in self.named_tensors() if t.numel() > 0])
+        )
+
     ###############################################################################
     #                      Special initialization operations                      #
     ###############################################################################
diff --git a/singd/structures/diagonal.py b/singd/structures/diagonal.py
index 74e31ee..affa493 100644
--- a/singd/structures/diagonal.py
+++ b/singd/structures/diagonal.py
@@ -6,6 +6,7 @@
 
 import torch
 from torch import Tensor, einsum, ones, zeros
+from torch.linalg import vector_norm
 
 from singd.structures.base import StructuredMatrix
 
@@ -179,6 +180,14 @@ def diag_add_(self, value: float) -> DiagonalMatrix:
         self._mat_diag.add_(value)
         return self
 
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        return vector_norm(self._mat_diag)
+
     ###############################################################################
     #                      Special initialization operations                      #
     ###############################################################################
diff --git a/singd/structures/hierarchical.py b/singd/structures/hierarchical.py
index 95c43a6..5cb8489 100644
--- a/singd/structures/hierarchical.py
+++ b/singd/structures/hierarchical.py
@@ -6,6 +6,7 @@
 
 import torch
 from torch import Tensor, arange, cat, einsum, ones, zeros
+from torch.linalg import vector_norm
 
 from singd.structures.base import StructuredMatrix
 from singd.structures.utils import diag_add_, lowest_precision, supported_eye
@@ -353,6 +354,16 @@ def diag_add_(self, value: float) -> HierarchicalMatrixTemplate:
         diag_add_(self.E, value)
         return self
 
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        return vector_norm(
+            cat([t.flatten() for _, t in self.named_tensors() if t.numel() > 0])
+        )
+
     ###############################################################################
     #                      Special initialization operations                      #
     ###############################################################################
diff --git a/singd/structures/recursive.py b/singd/structures/recursive.py
index 80e277c..f9d6225 100644
--- a/singd/structures/recursive.py
+++ b/singd/structures/recursive.py
@@ -39,18 +39,47 @@ def register_substructure(self, substructure: StructuredMatrix, name: str) -> No
         setattr(self, name, substructure)
         self._substructure_names.append(name)
 
-    def named_tensors(self) -> Iterator[Tuple[str, Tensor]]:
+    def named_tensors(
+        self, include_substructures: bool = True
+    ) -> Iterator[Tuple[str, Tensor]]:
         """Yield all tensors that represent the matrix and their names.
 
+        Args:
+            include_substructures: If `True`, also include the tensors of the
+                substructures. If `False`, exclude them. Default is `True`.
+
         Yields:
             A tuple of the tensor's name and the tensor itself.
         """
         for name in self._tensor_names:
             yield name, getattr(self, name)
+        if include_substructures:
+            for subname, substructure in self.named_substructures():
+                for name, tensor in substructure.named_tensors():
+                    yield f"{subname}.{name}", tensor
+
+    def named_substructures(self) -> Iterator[Tuple[str, StructuredMatrix]]:
+        """Yield all substructures and their names.
+
+        Yields:
+            A tuple of the substructure's name and the substructure itself.
+        """
         for name in self._substructure_names:
-            substructure = getattr(self, name)
-            for sub_name, tensor in substructure.named_tensors():
-                yield f"{name}.{sub_name}", tensor
+            yield name, getattr(self, name)
+
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        fro_squared = sum(
+            (t**2).sum() for _, t in self.named_tensors(include_substructures=False)
+        )
+        fro_squared_sub = sum(
+            s.frobenius_norm() ** 2 for _, s in self.named_substructures()
+        )
+        return (fro_squared + fro_squared_sub).sqrt()
 
 
 class RecursiveTopRightMatrixTemplate(RecursiveStructuredMatrix):
diff --git a/singd/structures/triltoeplitz.py b/singd/structures/triltoeplitz.py
index 3e85b5d..dd29863 100644
--- a/singd/structures/triltoeplitz.py
+++ b/singd/structures/triltoeplitz.py
@@ -6,6 +6,7 @@
 
 import torch
 from torch import Tensor, arange, cat, zeros
+from torch.linalg import vector_norm
 from torch.nn.functional import conv1d, pad
 
 from singd.structures.base import StructuredMatrix
@@ -191,6 +192,22 @@ def diag_add_(self, value: float) -> TrilToeplitzMatrix:
         self._lower_diags[0].add_(value)
         return self
 
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        (dim,) = self._lower_diags.shape
+        multiplicity = arange(
+            dim,
+            0,
+            step=-1,
+            dtype=self._lower_diags.dtype,
+            device=self._lower_diags.device,
+        )
+        return vector_norm(self._lower_diags * multiplicity.sqrt())
+
     ###############################################################################
     #                      Special initialization operations                      #
     ###############################################################################
diff --git a/singd/structures/triutoeplitz.py b/singd/structures/triutoeplitz.py
index 832a12e..1a6205d 100644
--- a/singd/structures/triutoeplitz.py
+++ b/singd/structures/triutoeplitz.py
@@ -6,6 +6,7 @@
 
 import torch
 from torch import Tensor, arange, cat, triu_indices, zeros
+from torch.linalg import vector_norm
 from torch.nn.functional import conv1d, pad
 
 from singd.structures.base import StructuredMatrix
@@ -189,6 +190,22 @@ def diag_add_(self, value: float) -> TriuToeplitzMatrix:
         self._upper_diags[0].add_(value)
         return self
 
+    def frobenius_norm(self) -> Tensor:
+        """Compute the Frobenius norm of the represented matrix.
+
+        Returns:
+            The Frobenius norm of the represented matrix.
+        """
+        (dim,) = self._upper_diags.shape
+        multiplicity = arange(
+            dim,
+            0,
+            step=-1,
+            dtype=self._upper_diags.dtype,
+            device=self._upper_diags.device,
+        )
+        return vector_norm(self._upper_diags * multiplicity.sqrt())
+
     ###############################################################################
     #                      Special initialization operations                      #
     ###############################################################################
diff --git a/test/structures/utils.py b/test/structures/utils.py
index b3c69df..145a738 100644
--- a/test/structures/utils.py
+++ b/test/structures/utils.py
@@ -11,7 +11,7 @@
 from matplotlib import pyplot as plt
 from pytest import mark
 from torch import Tensor, device, manual_seed, rand, zeros
-from torch.linalg import vector_norm
+from torch.linalg import matrix_norm, vector_norm
 
 from singd.structures.base import StructuredMatrix
 from singd.structures.utils import is_half_precision, supported_eye
@@ -585,6 +585,22 @@ def test_infinity_vector_norm(self, dev: device, dtype: torch.dtype):
             structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
             report_nonclose(truth, structured.infinity_vector_norm())
 
+    @mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
+    @mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
+    def test_frobenius_norm(self, dev: device, dtype: torch.dtype):
+        """Test Frobenius norm of a structured matrix.
+
+        Args:
+            dev: The device on which to run the test.
+            dtype: The data type of the matrices.
+        """
+        for dim in self.DIMS:
+            manual_seed(0)
+            sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype))
+            truth = matrix_norm(self.project(sym_mat))
+            structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
+            report_nonclose(truth, structured.frobenius_norm())
+
     @mark.expensive
     def test_visual(self):
         """Create pictures and animations of the structure.