diff --git a/dwave/plugins/torch/nn/modules/__init__.py b/dwave/plugins/torch/nn/modules/__init__.py
index 598a19c..1362c2f 100755
--- a/dwave/plugins/torch/nn/modules/__init__.py
+++ b/dwave/plugins/torch/nn/modules/__init__.py
@@ -14,4 +14,5 @@
 #
 
 from dwave.plugins.torch.nn.modules.linear import *
+from dwave.plugins.torch.nn.modules.orthogonal import *
 from dwave.plugins.torch.nn.modules.utils import *
diff --git a/dwave/plugins/torch/nn/modules/orthogonal.py b/dwave/plugins/torch/nn/modules/orthogonal.py
new file mode 100644
index 0000000..184889d
--- /dev/null
+++ b/dwave/plugins/torch/nn/modules/orthogonal.py
@@ -0,0 +1,209 @@
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import deque
+
+import torch
+import torch.nn as nn
+from einops import einsum
+
+from dwave.plugins.torch.nn.modules.utils import store_config
+
+__all__ = ["GivensRotationLayer"]
+
+
+def _get_blocks_edges(n: int) -> list[list[tuple[int, int]]]:
+    """Uses the circle method for round-robin pairing to create blocks of edges for parallel
+    Givens rotations.
+
+    A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs
+    in the same block can be rotated in parallel since they commute.
+
+    Args:
+        n (int): Dimension of the vector space on which the orthogonal layer acts.
+
+    Returns:
+        list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.
+
+    Note:
+        If n is odd, a dummy dimension is added to make it even. When using the resulting blocks
+        to build an orthogonal transformation, rotations involving the dummy dimension should be
+        ignored.
+    """
+    if n % 2 != 0:
+        n += 1  # Add a dummy dimension for odd n
+        is_odd = True
+    else:
+        is_odd = False
+
+    def circle_method(sequence):
+        seq_first_half = sequence[: len(sequence) // 2]
+        seq_second_half = sequence[len(sequence) // 2 :][::-1]
+        return list(zip(seq_first_half, seq_second_half))
+
+    blocks = []
+    sequence = list(range(n))
+    seqdeque = deque(sequence[1:])
+    for _ in range(n - 1):
+        pairs = circle_method(sequence)
+        if is_odd:
+            # Remove pairs involving the dummy dimension:
+            pairs = [pair for pair in pairs if n - 1 not in pair]
+        blocks.append(pairs)
+        seqdeque.rotate(1)
+        sequence[1:] = list(seqdeque)
+    return blocks
+
+
+class _RoundRobinGivens(torch.autograd.Function):
+    """Implements custom forward and backward passes for the parallel algorithms in
+    https://arxiv.org/abs/2106.00003
+    """
+
+    @staticmethod
+    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
+        """Creates a rotation matrix in n dimensions using parallel Givens transformations by
+        blocks.
+
+        Args:
+            ctx (context): Stores information for backward propagation.
+            angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all
+                rotations between pairs of dimensions.
+            blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices
+                that specify rotations between pairs of dimensions.
+                Each of the ``n - 1`` blocks contains ``n // 2`` pairs of independent rotations.
+            n (int): Dimension of the space.
+
+        Returns:
+            torch.Tensor: The ``n``-by-``n`` rotation matrix.
+        """
+        # blocks has shape (n_blocks, n/2, 2) and holds index pairs into the columns of U.
+        # Within each block the Givens rotations commute, so we can apply them in parallel.
+        U = torch.eye(n, device=angles.device, dtype=angles.dtype)
+        block_size = n // 2
+        idx_block = torch.arange(block_size, device=angles.device)
+        for b, block in enumerate(blocks):
+            # angles has shape (n_angles,), holding all angles in contiguous blocks.
+            angles_in_block = angles[idx_block + b * blocks.size(1)]  # shape (n/2,)
+            c = torch.cos(angles_in_block)
+            s = torch.sin(angles_in_block)
+            i_idx = block[:, 0]
+            j_idx = block[:, 1]
+            r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
+            r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
+            U[:, i_idx] = r_i
+            U[:, j_idx] = r_j
+        ctx.save_for_backward(angles, blocks, U)
+        return U
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
+        """Computes the VJP needed for backward propagation.
+
+        Args:
+            ctx (context): Contains information for backward propagation.
+            grad_output (torch.Tensor): A tensor containing the partial derivatives of the loss
+                with respect to the output of the forward pass, i.e., dL/dU.
+
+        Returns:
+            tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
+                angles. No calculation of gradients with respect to blocks or n is needed (cf.
+                forward method), so None is returned for these.
+        """
+        angles, blocks, Ufwd_saved = ctx.saved_tensors
+        Ufwd = Ufwd_saved.clone()
+        # Clone here as well: grad_output.t() alone is a view, and the in-place column updates
+        # below must not mutate grad_output, which autograd may reuse elsewhere.
+        M = grad_output.t().clone()  # transpose of dL/dU; grad_output has shape (n, n)
+        n = M.size(1)
+        block_size = n // 2
+        A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
+        grad_theta = torch.zeros_like(angles)
+        idx_block = torch.arange(block_size, device=angles.device)
+        for b, block in enumerate(blocks):
+            i_idx = block[:, 0]
+            j_idx = block[:, 1]
+            angles_in_block = angles[idx_block + b * block_size]  # shape (n/2,)
+            c = torch.cos(angles_in_block)
+            s = torch.sin(angles_in_block)
+            r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
+            r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
+            Ufwd[i_idx] = r_i
+            Ufwd[j_idx] = r_j
+            r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
+            r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
+            M[:, i_idx] = r_i
+            M[:, j_idx] = r_j
+            A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
+            grad_theta[idx_block + b * block_size] = A.sum(dim=1)
+        return grad_theta, None, None
+
+
+class GivensRotationLayer(nn.Module):
+    """An orthogonal layer implementing a rotation as a sequence of Givens rotations arranged in
+    a round-robin fashion.
+
+    Angles are arranged into blocks, where each block references rotations that can be applied in
+    parallel because these rotations commute.
+
+    Args:
+        n (int): Dimension of the input and output space. Must be at least 2.
+        bias (bool): If True, adds a learnable bias to the output. Default: True.
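+
+    Example:
+        A minimal, illustrative usage sketch::
+
+            layer = GivensRotationLayer(4)
+            x = torch.randn(8, 4)
+            y = layer(x)  # orthogonally rotated output, shape (8, 4)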
+    """

+    @store_config
+    def __init__(self, n: int, bias: bool = True):
+        super().__init__()
+        if not isinstance(n, int) or n <= 1:
+            raise ValueError(f"n must be an integer greater than 1, {n} was passed")
+        if not isinstance(bias, bool):
+            raise ValueError(f"bias must be a boolean, {bias} was passed")
+        self.n = n
+        self.n_angles = n * (n - 1) // 2
+        self.angles = nn.Parameter(torch.randn(self.n_angles))
+        blocks_edges = _get_blocks_edges(n)
+        self.register_buffer(
+            "blocks",
+            torch.tensor(blocks_edges, dtype=torch.long),
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(n))
+        else:
+            self.register_parameter("bias", None)
+
+    def _create_rotation_matrix(self) -> torch.Tensor:
+        """Computes the Givens rotation matrix."""
+        return _RoundRobinGivens.apply(self.angles, self.blocks, self.n)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the Givens rotation to the input tensor ``x``.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape ``(..., n)``.
+
+        Returns:
+            torch.Tensor: Rotated tensor of shape ``(..., n)``.
+        """
+        unitary = self._create_rotation_matrix()
+        rotated_x = einsum(x, unitary, "... i, o i -> ... o")
+        if self.bias is not None:
+            rotated_x += self.bias
+        return rotated_x
diff --git a/pyproject.toml b/pyproject.toml
index 74f1d43..ccddddd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
     "dimod",
     "dwave-system",
     "dwave-hybrid",
+    "einops",
 ]
 
 [project.readme]
diff --git a/requirements.txt b/requirements.txt
index 1583954..ae91bed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ torch==2.9.1
 dimod==0.12.18
 dwave-system==1.28.0
 dwave-hybrid==0.6.13
+einops==0.8.1
 
 # Development requirements
 reno==4.1.0
diff --git a/tests/helper_models.py b/tests/helper_models.py
new file mode 100644
index 0000000..9b7aaa0
--- /dev/null
+++ b/tests/helper_models.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+from einops import einsum
+
+
+class NaiveGivensRotationLayer(nn.Module):
+    """Naive implementation of a Givens rotation layer.
+
+    Sequentially applies all Givens rotations to implement an orthogonal transformation, in the
+    order given by ``blocks``, a tensor of shape (n_blocks, n/2, 2). Usually each block contains
+    pairs of indices in which no index appears more than once per block, but this implementation
+    does not rely on that assumption: indices may appear multiple times within a block, as long
+    as every pair of indices appears exactly once in the entire blocks tensor.
+
+    Args:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        bias (bool): If True, adds a learnable bias to the output. Default: True.
+
+    Note:
+        This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to
+        out_features.
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        assert in_features == out_features, (
+            "This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
+            "out_features."
+        )
+        self.n = in_features
+        self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+
+    def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
+        """Creates the rotation matrix from the Givens angles by applying the Givens rotations
+        sequentially, in the order specified by ``blocks``.
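+
+        For reference (mirroring the explicit construction below), each Givens factor ``Ge``
+        equals the identity except ``Ge[i, i] = Ge[j, j] = cos(angle)``,
+        ``Ge[i, j] = -sin(angle)`` and ``Ge[j, i] = sin(angle)``.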
+
+        Args:
+            angles (torch.Tensor): Givens rotation angles.
+            blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
+                None, the upper-triangular index pairs are reshaped into (n-1, n/2, 2) to form
+                the blocks. Defaults to None.
+
+        Returns:
+            torch.Tensor: Rotation matrix.
+        """
+        block_size = self.n // 2
+        if blocks is None:
+            # Create dummy blocks from triu indices:
+            triu_indices = torch.triu_indices(self.n, self.n, offset=1)
+            blocks = triu_indices.t().view(-1, block_size, 2)
+        U = torch.eye(self.n, dtype=angles.dtype)
+        for b, block in enumerate(blocks):
+            for k in range(block_size):
+                i = block[k, 0].item()
+                j = block[k, 1].item()
+                angle = angles[b * block_size + k]
+                c = torch.cos(angle)
+                s = torch.sin(angle)
+                Ge = torch.eye(self.n, dtype=angles.dtype)
+                Ge[i, i] = c
+                Ge[j, j] = c
+                Ge[i, j] = -s
+                Ge[j, i] = s
+                # Explicit Givens rotation
+                U = U @ Ge
+        return U
+
+    def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
+        """Applies the Givens rotation to the input tensor ``x``.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (..., n).
+            blocks (torch.Tensor): Blocks specifying the order of rotations.
+
+        Returns:
+            torch.Tensor: Rotated tensor of shape (..., n).
+        """
+        W = self._create_rotation_matrix(self.angles, blocks)
+        x = einsum(x, W, "... i, o i -> ... o")
+        if self.bias is not None:
+            x = x + self.bias
+        return x
diff --git a/tests/test_nn.py b/tests/test_nn.py
index c84929d..70d390c 100755
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -3,37 +3,47 @@
 import torch
 
 from parameterized import parameterized
 
-from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config
+from dwave.plugins.torch.nn import GivensRotationLayer, LinearBlock, SkipLinear, store_config
+from dwave.plugins.torch.nn.modules.orthogonal import _get_blocks_edges
 from tests.helper_functions import model_probably_good
+from tests.helper_models import NaiveGivensRotationLayer
 
 
 class TestUtils(unittest.TestCase):
     def test_store_config(self):
         with self.subTest("Simple case"):
+
             class MyModel(torch.nn.Module):
                 @store_config
-                def __init__(self, a, b=1, *, x=4, y='hello'):
+                def __init__(self, a, b=1, *, x=4, y="hello"):
                     super().__init__()
 
             model = MyModel(a=123, x=5)
-            self.assertDictEqual(dict(model.config),
-                                 {"a": 123, "b": 1, "x": 5, "y": "hello", "module_name": "MyModel"})
+            self.assertDictEqual(
+                dict(model.config),
+                {"a": 123, "b": 1, "x": 5, "y": "hello", "module_name": "MyModel"},
+            )
             model = MyModel(456)
-            self.assertDictEqual(dict(model.config),
-                                 {"a": 456, "b": 1, "x": 4, "y": "hello", "module_name": "MyModel"})
+            self.assertDictEqual(
+                dict(model.config),
+                {"a": 456, "b": 1, "x": 4, "y": "hello", "module_name": "MyModel"},
+            )
 
         with self.subTest("Case with default args"):
+
             class MyModel(torch.nn.Module):
                 @store_config
-                def __init__(self, b=1, x=4, y='hello'):
+                def __init__(self, b=1, x=4, y="hello"):
                     super().__init__()
 
             model = MyModel()
-            self.assertDictEqual(dict(model.config),
-                                 {"b": 1, "x": 4, "y": "hello", "module_name": "MyModel"})
+            self.assertDictEqual(
+                dict(model.config), {"b": 1, "x": 4, "y": "hello", "module_name": "MyModel"}
+            )
 
         with self.subTest("Empty config case failed."):
+
             class MyModel(torch.nn.Module):
                 @store_config
                 def __init__(self):
@@ -45,7 +55,7 @@ def __init__(self):
     def test_store_config_nested(self):
         class InnerModel(torch.nn.Module):
             @store_config
-            def __init__(self, a, b=1, *, x=4, y='hello'):
+            def __init__(self, a, b=1, *, x=4, y="hello"):
                 super().__init__()
 
         class OuterModel(torch.nn.Module):
@@ -56,14 +66,121 @@ def __init__(self, module_1, module_2=None):
         module_1 = InnerModel(a=123, x=5)
         module_2 = InnerModel(a="second", y="lol")
         model = OuterModel(module_1, module_2)
-        self.assertDictEqual(dict(model.config),
-                             {"module_1": module_1.config,
-                              "module_2": module_2.config,
-                              "module_name": "OuterModel"})
-        self.assertDictEqual(dict(model.config["module_1"]),
-                             dict(a=123, b=1, x=5, y="hello", module_name="InnerModel"))
-        self.assertDictEqual(dict(model.config["module_2"]),
-                             dict(a="second", b=1, x=4, y="lol", module_name="InnerModel"))
+        self.assertDictEqual(
+            dict(model.config),
+            {"module_1": module_1.config, "module_2": module_2.config, "module_name": "OuterModel"},
+        )
+        self.assertDictEqual(
+            dict(model.config["module_1"]),
+            dict(a=123, b=1, x=5, y="hello", module_name="InnerModel"),
+        )
+        self.assertDictEqual(
+            dict(model.config["module_2"]),
+            dict(a="second", b=1, x=4, y="lol", module_name="InnerModel"),
+        )
+
+
+class TestOrthogonal(unittest.TestCase):
+
+    def test_bad_initialization_parameters(self):
+        with self.subTest("n less than 2"):
+            with self.assertRaises(ValueError):
+                GivensRotationLayer(1)
+            with self.assertRaises(ValueError):
+                GivensRotationLayer(0)
+            with self.assertRaises(ValueError):
+                GivensRotationLayer(-5)
+        with self.subTest("n not integer"):
+            with self.assertRaises(ValueError):
+                GivensRotationLayer(3.5)
+            with self.assertRaises(ValueError):
+                GivensRotationLayer("a string")
+        with self.subTest("bias not boolean"):
+            with self.assertRaises(ValueError):
+                GivensRotationLayer(5, bias="another string")
+            with self.assertRaises(ValueError):
+                GivensRotationLayer(5, bias=1)
+
+    @parameterized.expand([9, 10])
+    def test_get_blocks_edges(self, n):
+        blocks = _get_blocks_edges(n)
+        # `blocks` is a list of lists of pairs, i.e., a list of blocks. Each block must contain
+        # pairs of dimensions such that each dimension appears only once per block.
+        # Also, across all blocks, each pair of dimensions must appear exactly once.
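+        # For example, with n = 4 the circle method yields blocks like
+        # [[(0, 3), (1, 2)], [(0, 2), (3, 1)], [(0, 1), (2, 3)]].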
+        appeared_pairs = set()
+        for block in blocks:
+            appeared_dims = set()
+            for i, j in block:
+                self.assertNotIn(i, appeared_dims)
+                self.assertNotIn(j, appeared_dims)
+                appeared_dims.add(i)
+                appeared_dims.add(j)
+                pair = (min(i, j), max(i, j))
+                self.assertNotIn(pair, appeared_pairs)
+                appeared_pairs.add(pair)
+        # Check that all pairs appeared:
+        for i in range(n):
+            for j in range(i + 1, n):
+                pair = (i, j)
+                self.assertIn(pair, appeared_pairs)
+
+    @parameterized.expand([(n, bias) for n in [9, 10] for bias in [True, False]])
+    def test_GivensRotationLayer(self, n, bias):
+        din = n
+        dout = n
+        model = GivensRotationLayer(n, bias=bias)
+        self.assertTrue(model_probably_good(model, (din,), (dout,)))
+
+    @parameterized.expand([(n, bias) for n in [9, 10] for bias in [True, False]])
+    def test_forward_agreement(self, n, bias):
+        layer = GivensRotationLayer(n, bias=bias).double()
+        naive_layer = NaiveGivensRotationLayer(n, n, bias=bias).double()
+        blocks = layer.blocks
+        U_naive = naive_layer._create_rotation_matrix(layer.angles, blocks)
+        U_parallel = layer._create_rotation_matrix()
+
+        # Test that the matrices are close
+        self.assertTrue(torch.allclose(U_naive, U_parallel, atol=1e-6))
+
+        # Test orthogonality:
+        I = torch.eye(n, dtype=U_parallel.dtype)
+        UU_T = U_parallel @ U_parallel.T
+        self.assertTrue(torch.allclose(I, UU_T, atol=1e-6))
+
+        # Random input:
+        with torch.no_grad():
+            # The layers' outputs are compared below, so the angles must match.
+            naive_layer.angles.copy_(layer.angles)
+            x = torch.randn((7, n), dtype=U_parallel.dtype)  # batch size 7
+            y_naive = naive_layer(x, blocks)
+            y_parallel = layer(x)
+            self.assertTrue(torch.allclose(y_naive, y_parallel, atol=1e-6))
+
+    @parameterized.expand([(n, bias) for n in [9, 10] for bias in [True, False]])
+    def test_backward_agreement(self, n, bias):
+        layer = GivensRotationLayer(n, bias=bias).double()
+        naive_layer = NaiveGivensRotationLayer(n, n, bias=bias).double()
+        blocks = layer.blocks
+
+        with torch.no_grad():
+            # The layers' gradients are compared below, so the angles must match.
+            naive_layer.angles.copy_(layer.angles)
+
+        x = torch.randn((7, n), dtype=layer.angles.dtype)  # batch size 7
+
+        y_naive = naive_layer(x, blocks)
+        y_parallel = layer(x)
+
+        # Define some dummy loss, e.g., closeness to the identity:
+        loss_naive = torch.sum((y_naive - x) ** 2)
+        loss_parallel = torch.sum((y_parallel - x) ** 2)
+        loss_naive.backward()
+        loss_parallel.backward()
+        grad_parallel = layer.angles.grad
+        grad_naive = naive_layer.angles.grad
+        self.assertTrue(torch.allclose(grad_naive, grad_parallel, atol=1e-6))
 
 
 class TestLinear(unittest.TestCase):
@@ -83,7 +198,7 @@ def test_SkipLinear_different_dim(self):
         din = 33
         dout = 99
         model = SkipLinear(din, dout)
-        self.assertTrue(model_probably_good(model, (din,), (dout, )))
+        self.assertTrue(model_probably_good(model, (din,), (dout,)))
 
     def test_SkipLinear_identity(self):
         # The skip linear function behaves as an identity function when the input dimension and
@@ -93,7 +208,7 @@ def test_SkipLinear_identity(self):
         x = torch.randn((dim,))
         y = model(x)
         self.assertTrue((x == y).all())
-        self.assertTrue(model_probably_good(model, (dim,), (dim, )))
+        self.assertTrue(model_probably_good(model, (dim,), (dim,)))
 
 
 if __name__ == "__main__":