From 5316d7d39558d2b8e2916ef6e3e2ec30a4971949 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vladimir=20Vargas=20Calder=C3=B3n?=
Date: Thu, 18 Dec 2025 18:56:31 -0500
Subject: [PATCH 1/2] Add Givens orthogonal layer

---
 dwave/plugins/torch/nn/modules/__init__.py | 1 +
 dwave/plugins/torch/nn/modules/orthogonal.py | 188 +++++++++++++++++++
 tests/helper_models.py | 88 +++++++++
 tests/test_nn.py | 98 ++++++++--
 4 files changed, 355 insertions(+), 20 deletions(-)
 create mode 100644 dwave/plugins/torch/nn/modules/orthogonal.py
 create mode 100644 tests/helper_models.py

diff --git a/dwave/plugins/torch/nn/modules/__init__.py b/dwave/plugins/torch/nn/modules/__init__.py
index 598a19c..1362c2f 100755
--- a/dwave/plugins/torch/nn/modules/__init__.py
+++ b/dwave/plugins/torch/nn/modules/__init__.py
@@ -14,4 +14,5 @@
 #
 from dwave.plugins.torch.nn.modules.linear import *
+from dwave.plugins.torch.nn.modules.orthogonal import *
 from dwave.plugins.torch.nn.modules.utils import *
diff --git a/dwave/plugins/torch/nn/modules/orthogonal.py b/dwave/plugins/torch/nn/modules/orthogonal.py
new file mode 100644
index 0000000..880b5fe
--- /dev/null
+++ b/dwave/plugins/torch/nn/modules/orthogonal.py
@@ -0,0 +1,188 @@
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import deque
+
+import torch
+import torch.nn as nn
+from einops import einsum
+
+__all__ = ["get_blocks_edges", "GivensRotationLayer"]
+
+
+def get_blocks_edges(n: int) -> list[list[tuple[int, int]]]:
+    """
+    Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens
+    rotations.
+
+    A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs
+    in the same block can be rotated in parallel since they commute.
+
+    Args:
+        n (int): Dimension of the vector space onto which an orthogonal layer will be built.
+
+    Returns:
+        list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.
+    """
+
+    assert n % 2 == 0, "n must be even"  # TODO: discuss odd case with Firas
+
+    def circle_method(sequence):
+        seq_first_half = sequence[: len(sequence) // 2]
+        seq_second_half = sequence[len(sequence) // 2 :][::-1]
+        return list(zip(seq_first_half, seq_second_half))
+
+    blocks = []
+    sequence = list(range(n))
+    seqdeque = deque(sequence[1:])
+    for _ in range(n - 1):
+        blocks.append(circle_method(sequence))
+        seqdeque.rotate(1)
+        sequence[1:] = list(seqdeque)
+    return blocks
+
+
+class RoundRobinGivens(torch.autograd.Function):
+    """
+    Implements custom forward and backward passes for the parallel Givens rotation algorithms
+    described in https://arxiv.org/abs/2106.00003
+    """
+
+    @staticmethod
+    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
+        """
+        Creates a rotation matrix in n dimensions using parallel Givens transformations by blocks.
+
+        Args:
+            ctx (context): Stores information for backward propagation.
+            angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotation
+                angles between pairs of dimensions.
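+                Angles are laid out contiguously block by block: entry ``b * (n // 2) + k``
+                holds the angle for pair ``k`` of block ``b``.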
+ blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that + specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2 + pairs of independent rotations. + n (int): Dimension of the space. + + Returns: + torch.Tensor: The nxn rotation matrix. + """ + # Blocks is of shape (n_blocks, n/2, 2) containing indices for angles + # Within each block, each Givens rotation is commuting, so we can apply them in parallel + U = torch.eye(n, device=angles.device) + block_size = n // 2 + idx_block = torch.arange(block_size, device=angles.device) + for b, block in enumerate(blocks): + # angles is of shape (n_angles,) containing all angles for contiguous blocks. + angles_in_block = angles[idx_block + b * blocks.size(1)] # shape (n/2,) + c = torch.cos(angles_in_block) + s = torch.sin(angles_in_block) + i_idx = block[:, 0] + j_idx = block[:, 1] + r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx] + r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx] + U[:, i_idx] = r_i + U[:, j_idx] = r_j + ctx.save_for_backward(angles, blocks, U) + return U + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """ + Computes the VJP needed for backward propagation. + + Args: + ctx (context): Contains information for backward propagation. + grad_output (torch.Tensor): A tensor containing the partial derivatives for the loss + with respect to the output of the forward pass, i.e., dL/dU. + + Returns: + torch.Tensor: The gradient of the loss with respect to the input angles. + """ + angles, blocks, Ufwd_saved = ctx.saved_tensors + Ufwd = Ufwd_saved.clone() + M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n) + n = M.size(1) + block_size = n // 2 + A = torch.zeros((block_size, n), device=grad_output.device) + grad_theta = torch.zeros_like(angles) + idx_block = torch.arange(block_size, device=grad_output.device) + for b, block in enumerate(blocks): + i_idx = block[:, 0] + j_idx = block[:, 1] + angles_in_block = angles[idx_block + b * block_size] # shape (n/2,) + c = torch.cos(angles_in_block) + s = torch.sin(angles_in_block) + r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx] + r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx] + Ufwd[i_idx] = r_i + Ufwd[j_idx] = r_j + r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx] + r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx] + M[:, i_idx] = r_i + M[:, j_idx] = r_j + A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx] + grad_theta[idx_block + b * block_size] = A.sum(dim=1) + return grad_theta, None, None + + +class GivensRotationLayer(nn.Module): + """ + An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged in a + round-robin fashion. + + Angles are arranged into blocks, where each block references rotations that can be applied in + parallel because these rotations commute. + + Args: + n (int): Dimension of the input and output space. + bias (bool): If True, adds a learnable bias to the output. Default: True. 
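+
+    Example (a minimal usage sketch; any even ``n`` works, ``n = 4`` shown)::
+
+        >>> layer = GivensRotationLayer(4)
+        >>> x = torch.randn(7, 4)
+        >>> layer(x).shape
+        torch.Size([7, 4])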
+    """
+
+    def __init__(self, n: int, bias: bool = True):
+        super().__init__()
+        assert n % 2 == 0, "n must be even"  # TODO: discuss odd case with Firas
+        self.n = n
+        self.n_angles = n * (n - 1) // 2
+        self.angles = nn.Parameter(torch.randn(self.n_angles))
+        blocks_edges = get_blocks_edges(n)
+        self.register_buffer(
+            "blocks",
+            torch.tensor(blocks_edges, dtype=torch.long),
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(n))
+        else:
+            self.register_parameter("bias", None)
+
+    def _create_rotation_matrix(self) -> torch.Tensor:
+        """
+        Computes the Givens rotation matrix.
+        """
+        U = RoundRobinGivens.apply(self.angles, self.blocks, self.n)
+        return U
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Applies the Givens rotation to the input tensor ``x``.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (..., n).
+
+        Returns:
+            torch.Tensor: Rotated tensor of shape (..., n).
+        """
+        U = self._create_rotation_matrix()
+        rotated_x = einsum(x, U, "... i, o i -> ... o")
+        if self.bias is not None:
+            rotated_x = rotated_x + self.bias
+        return rotated_x
diff --git a/tests/helper_models.py b/tests/helper_models.py
new file mode 100644
index 0000000..b8a9104
--- /dev/null
+++ b/tests/helper_models.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+from einops import einsum
+
+
+class NaiveGivensRotationLayer(nn.Module):
+    """
+    Naive implementation of a Givens rotation layer.
+
+    Sequentially applies all Givens rotations to implement an orthogonal transformation, in the
+    order provided by ``blocks`` (shape ``(n_blocks, n/2, 2)``). Usually each block contains pairs
+    of indices such that no index appears more than once within a block, but this implementation
+    does not rely on that assumption: indices may appear multiple times in a block, as long as
+    every pair of indices appears exactly once in the entire blocks tensor.
+
+    Args:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        bias (bool): If True, adds a learnable bias to the output. Default: True.
+
+    Note:
+        This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to
+        out_features.
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        assert in_features == out_features, (
+            "This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
+            "out_features."
+        )
+        self.n = in_features
+        self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+
+    def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
+        """
+        Creates the rotation matrix from the Givens angles by applying the Givens rotations
+        sequentially, in the order specified by ``blocks``.
+
+        Args:
+            angles (torch.Tensor): Givens rotation angles.
+            blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
+                None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to create
+                the blocks. Defaults to None.
+
+        Returns:
+            torch.Tensor: Rotation matrix.
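+
+        Note:
+            The default blocks built from ``triu_indices`` can repeat an index within a block
+            (e.g., for ``n = 4`` the first block is ``[(0, 1), (0, 2)]``), which is fine here
+            because the rotations are applied sequentially.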
+ """ + block_size = self.n // 2 + if blocks is None: + # Create dummy blocks from triu indices: + triu_indices = torch.triu_indices(self.n, self.n, offset=1) + blocks = triu_indices.t().view(-1, block_size, 2) + U = torch.eye(self.n) + for b, block in enumerate(blocks): + for k in range(block_size): + i = block[k, 0].item() + j = block[k, 1].item() + angle = angles[b * block_size + k] + c = torch.cos(angle) + s = torch.sin(angle) + # Need to clone because of pytorch. (This wouldn't happen in JAX) + r_i = c * U[:, i].clone() + s * U[:, j].clone() + r_j = -s * U[:, i].clone() + c * U[:, j].clone() + U[:, i] = r_i + U[:, j] = r_j + return U + + def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor: + """ + Applies the Givens rotation to the input tensor ``x``. + + Args: + x (torch.Tensor): Input tensor of shape (..., n). + blocks (torch.Tensor): Blocks specifying the order of rotations. + + Returns: + torch.Tensor: Rotated tensor of shape (..., n). + """ + W = self._create_rotation_matrix(self.angles, blocks) + x = einsum(x, W, "... i, o i -> ... o") + if self.bias is not None: + x = x + self.bias + return x diff --git a/tests/test_nn.py b/tests/test_nn.py index c84929d..7cdf6cc 100755 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -3,37 +3,46 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config +from dwave.plugins.torch.nn import GivensRotationLayer, LinearBlock, SkipLinear, store_config from tests.helper_functions import model_probably_good +from tests.helper_models import NaiveGivensRotationLayer class TestUtils(unittest.TestCase): def test_store_config(self): with self.subTest("Simple case"): + class MyModel(torch.nn.Module): @store_config - def __init__(self, a, b=1, *, x=4, y='hello'): + def __init__(self, a, b=1, *, x=4, y="hello"): super().__init__() model = MyModel(a=123, x=5) - self.assertDictEqual(dict(model.config), - {"a": 123, "b": 1, "x": 5, "y": "hello", "module_name": "MyModel"}) + self.assertDictEqual( + dict(model.config), + {"a": 123, "b": 1, "x": 5, "y": "hello", "module_name": "MyModel"}, + ) model = MyModel(456) - self.assertDictEqual(dict(model.config), - {"a": 456, "b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}) + self.assertDictEqual( + dict(model.config), + {"a": 456, "b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}, + ) with self.subTest("Case with default args"): + class MyModel(torch.nn.Module): @store_config - def __init__(self, b=1, x=4, y='hello'): + def __init__(self, b=1, x=4, y="hello"): super().__init__() model = MyModel() - self.assertDictEqual(dict(model.config), - {"b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}) + self.assertDictEqual( + dict(model.config), {"b": 1, "x": 4, "y": "hello", "module_name": "MyModel"} + ) with self.subTest("Empty config case failed."): + class MyModel(torch.nn.Module): @store_config def __init__(self): @@ -45,7 +54,7 @@ def __init__(self): def test_store_config_nested(self): class InnerModel(torch.nn.Module): @store_config - def __init__(self, a, b=1, *, x=4, y='hello'): + def __init__(self, a, b=1, *, x=4, y="hello"): super().__init__() class OuterModel(torch.nn.Module): @@ -56,14 +65,63 @@ def __init__(self, module_1, module_2=None): module_1 = InnerModel(a=123, x=5) module_2 = InnerModel(a="second", y="lol") model = OuterModel(module_1, module_2) - self.assertDictEqual(dict(model.config), - {"module_1": module_1.config, - "module_2": module_2.config, - "module_name": "OuterModel"}) - 
self.assertDictEqual(dict(model.config["module_1"]), - dict(a=123, b=1, x=5, y="hello", module_name="InnerModel")) - self.assertDictEqual(dict(model.config["module_2"]), - dict(a="second", b=1, x=4, y="lol", module_name="InnerModel")) + self.assertDictEqual( + dict(model.config), + {"module_1": module_1.config, "module_2": module_2.config, "module_name": "OuterModel"}, + ) + self.assertDictEqual( + dict(model.config["module_1"]), + dict(a=123, b=1, x=5, y="hello", module_name="InnerModel"), + ) + self.assertDictEqual( + dict(model.config["module_2"]), + dict(a="second", b=1, x=4, y="lol", module_name="InnerModel"), + ) + + +class TestOrthogonal(unittest.TestCase): + @parameterized.expand([(n, bias) for n in [4, 6, 10] for bias in [True, False]]) + def test_forward_agreement(self, n, bias): + layer = GivensRotationLayer(n, bias=bias) + naive_layer = NaiveGivensRotationLayer(n, n, bias=bias) + blocks = layer.blocks + U_naive = naive_layer._create_rotation_matrix(layer.angles, blocks) + U_parallel = layer._create_rotation_matrix() + + # Test that the matrices are close + self.assertTrue(torch.allclose(U_naive, U_parallel, atol=1e-6)) + + # Test orthogonality: + I = torch.eye(n) + UU_T = U_parallel @ U_parallel.T + self.assertTrue(torch.allclose(I, UU_T, atol=1e-6)) + + # Random input: + x = torch.randn((7, n)) # batch size 7 + y_naive = naive_layer(x, blocks) + y_parallel = layer(x) + self.assertTrue(torch.allclose(y_naive, y_parallel, atol=1e-6)) + + @parameterized.expand([(n, bias) for n in [4, 6, 10] for bias in [True, False]]) + def test_backward_agreement(self, n, bias): + layer = GivensRotationLayer(n, bias=bias) + naive_layer = NaiveGivensRotationLayer(n, n, bias=bias) + blocks = layer.blocks + + x = torch.randn((7, n)) # batch size 7 + + y_naive = naive_layer(x, blocks) + y_parallel = layer(x) + + # Define some dummy loss, e.g. 
closeness to the identity: + loss_naive = torch.sum((y_naive - x) ** 2) + loss_parallel = torch.sum((y_parallel - x) ** 2) + loss_naive.backward() + loss_parallel.backward() + grad_parallel = layer.angles.grad + grad_naive = naive_layer.angles.grad + + self.assertTrue(torch.allclose(grad_naive, grad_parallel, atol=1e-6)) class TestLinear(unittest.TestCase): @@ -83,7 +141,7 @@ def test_SkipLinear_different_dim(self): din = 33 dout = 99 model = SkipLinear(din, dout) - self.assertTrue(model_probably_good(model, (din,), (dout, ))) + self.assertTrue(model_probably_good(model, (din,), (dout,))) def test_SkipLinear_identity(self): # The skip linear function behaves as an identity function when the input dimension and @@ -93,7 +151,7 @@ def test_SkipLinear_identity(self): x = torch.randn((dim,)) y = model(x) self.assertTrue((x == y).all()) - self.assertTrue(model_probably_good(model, (dim,), (dim, ))) + self.assertTrue(model_probably_good(model, (dim,), (dim,))) if __name__ == "__main__": From 6b78c794d01afafe56c9f25daa196ec72d1bfcae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladimir=20Vargas=20Calder=C3=B3n?= Date: Thu, 18 Dec 2025 18:56:31 -0500 Subject: [PATCH 2/2] Add Givens orthogonal layer --- dwave/plugins/torch/nn/modules/__init__.py | 1 + dwave/plugins/torch/nn/modules/orthogonal.py | 188 +++++++++++++++++++ pyproject.toml | 1 + requirements.txt | 1 + tests/helper_models.py | 88 +++++++++ tests/test_nn.py | 98 ++++++++-- 6 files changed, 357 insertions(+), 20 deletions(-) create mode 100644 dwave/plugins/torch/nn/modules/orthogonal.py create mode 100644 tests/helper_models.py diff --git a/dwave/plugins/torch/nn/modules/__init__.py b/dwave/plugins/torch/nn/modules/__init__.py index 598a19c..1362c2f 100755 --- a/dwave/plugins/torch/nn/modules/__init__.py +++ b/dwave/plugins/torch/nn/modules/__init__.py @@ -14,4 +14,5 @@ # from dwave.plugins.torch.nn.modules.linear import * +from dwave.plugins.torch.nn.modules.orthogonal import * from dwave.plugins.torch.nn.modules.utils import * diff --git a/dwave/plugins/torch/nn/modules/orthogonal.py b/dwave/plugins/torch/nn/modules/orthogonal.py new file mode 100644 index 0000000..880b5fe --- /dev/null +++ b/dwave/plugins/torch/nn/modules/orthogonal.py @@ -0,0 +1,188 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque + +import torch +import torch.nn as nn +from einops import einsum + +__all__ = ["get_blocks_edges", "GivensRotationLayer"] + + +def get_blocks_edges(n: int) -> list[list[tuple[int, int]]]: + """ + Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens + rotations. + + A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs + in the same block can be rotated in parallel since they commute. + + Args: + n (int): Dimension of the vector space onto which an orthogonal layer will be built. + + Returns: + list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations. 
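+
+    Example (illustrative; the expected output below was traced by hand for ``n = 4``)::
+
+        >>> get_blocks_edges(4)
+        [[(0, 3), (1, 2)], [(0, 2), (3, 1)], [(0, 1), (2, 3)]]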
+    """
+
+    assert n % 2 == 0, "n must be even"  # TODO: discuss odd case with Firas
+
+    def circle_method(sequence):
+        seq_first_half = sequence[: len(sequence) // 2]
+        seq_second_half = sequence[len(sequence) // 2 :][::-1]
+        return list(zip(seq_first_half, seq_second_half))
+
+    blocks = []
+    sequence = list(range(n))
+    seqdeque = deque(sequence[1:])
+    for _ in range(n - 1):
+        blocks.append(circle_method(sequence))
+        seqdeque.rotate(1)
+        sequence[1:] = list(seqdeque)
+    return blocks
+
+
+class RoundRobinGivens(torch.autograd.Function):
+    """
+    Implements custom forward and backward passes for the parallel Givens rotation algorithms
+    described in https://arxiv.org/abs/2106.00003
+    """
+
+    @staticmethod
+    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
+        """
+        Creates a rotation matrix in n dimensions using parallel Givens transformations by blocks.
+
+        Args:
+            ctx (context): Stores information for backward propagation.
+            angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotation
+                angles between pairs of dimensions.
+                Angles are laid out contiguously block by block: entry ``b * (n // 2) + k``
+                holds the angle for pair ``k`` of block ``b``.
+            blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
+                specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
+                pairs of independent rotations.
+            n (int): Dimension of the space.
+
+        Returns:
+            torch.Tensor: The nxn rotation matrix.
+        """
+        # Blocks is of shape (n_blocks, n/2, 2) containing indices for angles
+        # Within each block, each Givens rotation is commuting, so we can apply them in parallel
+        U = torch.eye(n, device=angles.device)
+        block_size = n // 2
+        idx_block = torch.arange(block_size, device=angles.device)
+        for b, block in enumerate(blocks):
+            # angles is of shape (n_angles,) containing all angles for contiguous blocks.
+            angles_in_block = angles[idx_block + b * blocks.size(1)]  # shape (n/2,)
+            c = torch.cos(angles_in_block)
+            s = torch.sin(angles_in_block)
+            i_idx = block[:, 0]
+            j_idx = block[:, 1]
+            r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
+            r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
+            U[:, i_idx] = r_i
+            U[:, j_idx] = r_j
+        ctx.save_for_backward(angles, blocks, U)
+        return U
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        """
+        Computes the VJP needed for backward propagation.
+
+        Args:
+            ctx (context): Contains information for backward propagation.
+            grad_output (torch.Tensor): A tensor containing the partial derivatives for the loss
+                with respect to the output of the forward pass, i.e., dL/dU.
+
+        Returns:
+            torch.Tensor: The gradient of the loss with respect to the input angles.
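+
+        Note:
+            Roughly, the saved forward product and the incoming cotangent are updated block by
+            block, and each block's angle gradients are read off along the way, so the
+            intermediate rotation matrices from the forward pass do not need to be stored.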
+ """ + angles, blocks, Ufwd_saved = ctx.saved_tensors + Ufwd = Ufwd_saved.clone() + M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n) + n = M.size(1) + block_size = n // 2 + A = torch.zeros((block_size, n), device=grad_output.device) + grad_theta = torch.zeros_like(angles) + idx_block = torch.arange(block_size, device=grad_output.device) + for b, block in enumerate(blocks): + i_idx = block[:, 0] + j_idx = block[:, 1] + angles_in_block = angles[idx_block + b * block_size] # shape (n/2,) + c = torch.cos(angles_in_block) + s = torch.sin(angles_in_block) + r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx] + r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx] + Ufwd[i_idx] = r_i + Ufwd[j_idx] = r_j + r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx] + r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx] + M[:, i_idx] = r_i + M[:, j_idx] = r_j + A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx] + grad_theta[idx_block + b * block_size] = A.sum(dim=1) + return grad_theta, None, None + + +class GivensRotationLayer(nn.Module): + """ + An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged in a + round-robin fashion. + + Angles are arranged into blocks, where each block references rotations that can be applied in + parallel because these rotations commute. + + Args: + n (int): Dimension of the input and output space. + bias (bool): If True, adds a learnable bias to the output. Default: True. + """ + + def __init__(self, n: int, bias: bool = True): + super().__init__() + assert n % 2 == 0, "n must be even" # TODO: discuss odd case with Firas + self.n = n + self.n_angles = n * (n - 1) // 2 + self.angles = nn.Parameter(torch.randn(self.n_angles)) + blocks_edges = get_blocks_edges(n) + self.register_buffer( + "blocks", + torch.tensor(blocks_edges, dtype=torch.long), + ) + if bias: + self.bias = nn.Parameter(torch.zeros(n)) + else: + self.register_parameter("bias", None) + + def _create_rotation_matrix(self) -> torch.Tensor: + """ + Computes the Givens rotation matrix. + """ + U = RoundRobinGivens.apply(self.angles, self.blocks, self.n) + return U + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies the Givens rotation to the input tensor ``x``. + + Args: + x (torch.Tensor): Input tensor of shape (..., n). + + Returns: + torch.Tensor: Rotated tensor of shape (..., n). + """ + U = self._create_rotation_matrix() + rotated_x = einsum(x, U, "... i, o i -> ... o") + if self.bias is not None: + rotated_x = rotated_x + self.bias + return rotated_x diff --git a/pyproject.toml b/pyproject.toml index 74f1d43..ccddddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "dimod", "dwave-system", "dwave-hybrid", + "einops", ] [project.readme] diff --git a/requirements.txt b/requirements.txt index 1583954..ae91bed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ torch==2.9.1 dimod==0.12.18 dwave-system==1.28.0 dwave-hybrid==0.6.13 +einops==0.8.1 # Development requirements reno==4.1.0 diff --git a/tests/helper_models.py b/tests/helper_models.py new file mode 100644 index 0000000..b8a9104 --- /dev/null +++ b/tests/helper_models.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +from einops import einsum + + +class NaiveGivensRotationLayer(nn.Module): + """ + Naive implementation of a Givens rotation layer. 
+
+    Sequentially applies all Givens rotations to implement an orthogonal transformation, in the
+    order provided by ``blocks`` (shape ``(n_blocks, n/2, 2)``). Usually each block contains pairs
+    of indices such that no index appears more than once within a block, but this implementation
+    does not rely on that assumption: indices may appear multiple times in a block, as long as
+    every pair of indices appears exactly once in the entire blocks tensor.
+
+    Args:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        bias (bool): If True, adds a learnable bias to the output. Default: True.
+
+    Note:
+        This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to
+        out_features.
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        assert in_features == out_features, (
+            "This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
+            "out_features."
+        )
+        self.n = in_features
+        self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+
+    def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
+        """
+        Creates the rotation matrix from the Givens angles by applying the Givens rotations
+        sequentially, in the order specified by ``blocks``.
+
+        Args:
+            angles (torch.Tensor): Givens rotation angles.
+            blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
+                None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to create
+                the blocks. Defaults to None.
+
+        Returns:
+            torch.Tensor: Rotation matrix.
+
+        Note:
+            The default blocks built from ``triu_indices`` can repeat an index within a block
+            (e.g., for ``n = 4`` the first block is ``[(0, 1), (0, 2)]``), which is fine here
+            because the rotations are applied sequentially.
+        """
+        block_size = self.n // 2
+        if blocks is None:
+            # Create dummy blocks from triu indices:
+            triu_indices = torch.triu_indices(self.n, self.n, offset=1)
+            blocks = triu_indices.t().view(-1, block_size, 2)
+        U = torch.eye(self.n)
+        for b, block in enumerate(blocks):
+            for k in range(block_size):
+                i = block[k, 0].item()
+                j = block[k, 1].item()
+                angle = angles[b * block_size + k]
+                c = torch.cos(angle)
+                s = torch.sin(angle)
+                # Need to clone because of pytorch. (This wouldn't happen in JAX)
+                r_i = c * U[:, i].clone() + s * U[:, j].clone()
+                r_j = -s * U[:, i].clone() + c * U[:, j].clone()
+                U[:, i] = r_i
+                U[:, j] = r_j
+        return U
+
+    def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
+        """
+        Applies the Givens rotation to the input tensor ``x``.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (..., n).
+            blocks (torch.Tensor): Blocks specifying the order of rotations.
+
+        Returns:
+            torch.Tensor: Rotated tensor of shape (..., n).
+        """
+        W = self._create_rotation_matrix(self.angles, blocks)
+        x = einsum(x, W, "... i, o i -> ... 
o") + if self.bias is not None: + x = x + self.bias + return x diff --git a/tests/test_nn.py b/tests/test_nn.py index c84929d..7cdf6cc 100755 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -3,37 +3,46 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config +from dwave.plugins.torch.nn import GivensRotationLayer, LinearBlock, SkipLinear, store_config from tests.helper_functions import model_probably_good +from tests.helper_models import NaiveGivensRotationLayer class TestUtils(unittest.TestCase): def test_store_config(self): with self.subTest("Simple case"): + class MyModel(torch.nn.Module): @store_config - def __init__(self, a, b=1, *, x=4, y='hello'): + def __init__(self, a, b=1, *, x=4, y="hello"): super().__init__() model = MyModel(a=123, x=5) - self.assertDictEqual(dict(model.config), - {"a": 123, "b": 1, "x": 5, "y": "hello", "module_name": "MyModel"}) + self.assertDictEqual( + dict(model.config), + {"a": 123, "b": 1, "x": 5, "y": "hello", "module_name": "MyModel"}, + ) model = MyModel(456) - self.assertDictEqual(dict(model.config), - {"a": 456, "b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}) + self.assertDictEqual( + dict(model.config), + {"a": 456, "b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}, + ) with self.subTest("Case with default args"): + class MyModel(torch.nn.Module): @store_config - def __init__(self, b=1, x=4, y='hello'): + def __init__(self, b=1, x=4, y="hello"): super().__init__() model = MyModel() - self.assertDictEqual(dict(model.config), - {"b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}) + self.assertDictEqual( + dict(model.config), {"b": 1, "x": 4, "y": "hello", "module_name": "MyModel"} + ) with self.subTest("Empty config case failed."): + class MyModel(torch.nn.Module): @store_config def __init__(self): @@ -45,7 +54,7 @@ def __init__(self): def test_store_config_nested(self): class InnerModel(torch.nn.Module): @store_config - def __init__(self, a, b=1, *, x=4, y='hello'): + def __init__(self, a, b=1, *, x=4, y="hello"): super().__init__() class OuterModel(torch.nn.Module): @@ -56,14 +65,63 @@ def __init__(self, module_1, module_2=None): module_1 = InnerModel(a=123, x=5) module_2 = InnerModel(a="second", y="lol") model = OuterModel(module_1, module_2) - self.assertDictEqual(dict(model.config), - {"module_1": module_1.config, - "module_2": module_2.config, - "module_name": "OuterModel"}) - self.assertDictEqual(dict(model.config["module_1"]), - dict(a=123, b=1, x=5, y="hello", module_name="InnerModel")) - self.assertDictEqual(dict(model.config["module_2"]), - dict(a="second", b=1, x=4, y="lol", module_name="InnerModel")) + self.assertDictEqual( + dict(model.config), + {"module_1": module_1.config, "module_2": module_2.config, "module_name": "OuterModel"}, + ) + self.assertDictEqual( + dict(model.config["module_1"]), + dict(a=123, b=1, x=5, y="hello", module_name="InnerModel"), + ) + self.assertDictEqual( + dict(model.config["module_2"]), + dict(a="second", b=1, x=4, y="lol", module_name="InnerModel"), + ) + + +class TestOrthogonal(unittest.TestCase): + @parameterized.expand([(n, bias) for n in [4, 6, 10] for bias in [True, False]]) + def test_forward_agreement(self, n, bias): + layer = GivensRotationLayer(n, bias=bias) + naive_layer = NaiveGivensRotationLayer(n, n, bias=bias) + blocks = layer.blocks + U_naive = naive_layer._create_rotation_matrix(layer.angles, blocks) + U_parallel = layer._create_rotation_matrix() + + # Test that the matrices are close + 
self.assertTrue(torch.allclose(U_naive, U_parallel, atol=1e-6)) + + # Test orthogonality: + I = torch.eye(n) + UU_T = U_parallel @ U_parallel.T + self.assertTrue(torch.allclose(I, UU_T, atol=1e-6)) + + # Random input: + x = torch.randn((7, n)) # batch size 7 + y_naive = naive_layer(x, blocks) + y_parallel = layer(x) + self.assertTrue(torch.allclose(y_naive, y_parallel, atol=1e-6)) + + @parameterized.expand([(n, bias) for n in [4, 6, 10] for bias in [True, False]]) + def test_backward_agreement(self, n, bias): + layer = GivensRotationLayer(n, bias=bias) + naive_layer = NaiveGivensRotationLayer(n, n, bias=bias) + blocks = layer.blocks + + x = torch.randn((7, n)) # batch size 7 + + y_naive = naive_layer(x, blocks) + y_parallel = layer(x) + + # Define some dummy loss, e.g. closeness to the identity: + loss_naive = torch.sum((y_naive - x) ** 2) + loss_parallel = torch.sum((y_parallel - x) ** 2) + loss_naive.backward() + loss_parallel.backward() + grad_parallel = layer.angles.grad + grad_naive = naive_layer.angles.grad + + self.assertTrue(torch.allclose(grad_naive, grad_parallel, atol=1e-6)) class TestLinear(unittest.TestCase): @@ -83,7 +141,7 @@ def test_SkipLinear_different_dim(self): din = 33 dout = 99 model = SkipLinear(din, dout) - self.assertTrue(model_probably_good(model, (din,), (dout, ))) + self.assertTrue(model_probably_good(model, (din,), (dout,))) def test_SkipLinear_identity(self): # The skip linear function behaves as an identity function when the input dimension and @@ -93,7 +151,7 @@ def test_SkipLinear_identity(self): x = torch.randn((dim,)) y = model(x) self.assertTrue((x == y).all()) - self.assertTrue(model_probably_good(model, (dim,), (dim, ))) + self.assertTrue(model_probably_good(model, (dim,), (dim,))) if __name__ == "__main__":