1 change: 1 addition & 0 deletions dwave/plugins/torch/nn/modules/__init__.py
@@ -14,4 +14,5 @@
#

from dwave.plugins.torch.nn.modules.linear import *
from dwave.plugins.torch.nn.modules.orthogonal import *
from dwave.plugins.torch.nn.modules.utils import *
194 changes: 194 additions & 0 deletions dwave/plugins/torch/nn/modules/orthogonal.py
@@ -0,0 +1,194 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque

import torch
import torch.nn as nn
from einops import einsum

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["get_blocks_edges", "GivensRotationLayer"]


def get_blocks_edges(n: int) -> list[list[tuple[int, int]]]:
"""Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens
rotations.

A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs
in the same block can be rotated in parallel since they commute.

Args:
        n (int): Dimension of the vector space on which the orthogonal layer acts.

Returns:
list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.

Note:
If n is odd, a dummy dimension is added to make it even. When using the resulting blocks to
build an orthogonal transformation, rotations involving the dummy dimension should be
ignored.
"""
if n % 2 != 0:
n += 1 # Add a dummy dimension for odd n
is_odd = True
else:
is_odd = False

def circle_method(sequence):
seq_first_half = sequence[: len(sequence) // 2]
seq_second_half = sequence[len(sequence) // 2 :][::-1]
return list(zip(seq_first_half, seq_second_half))

blocks = []
sequence = list(range(n))
seqdeque = deque(sequence[1:])
for _ in range(n - 1):
pairs = circle_method(sequence)
if is_odd:
# Remove pairs involving the dummy dimension:
pairs = [pair for pair in pairs if n - 1 not in pair]
blocks.append(pairs)
seqdeque.rotate(1)
sequence[1:] = list(seqdeque)
return blocks
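
# A quick sanity check of the schedule (illustrative, traced by hand from the
# code above): for n = 4 the circle method yields n - 1 = 3 blocks of
# n // 2 = 2 commuting pairs each, and every unordered pair appears exactly
# once across all blocks:
#
#     >>> get_blocks_edges(4)
#     [[(0, 3), (1, 2)], [(0, 2), (3, 1)], [(0, 1), (2, 3)]]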


class RoundRobinGivens(torch.autograd.Function):
"""Implements custom forward and backward passes to implement the parallel algorithms in
https://arxiv.org/abs/2106.00003
"""

@staticmethod
def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
"""Creates a rotation matrix in n dimensions using parallel Givens transformations by
blocks.

Args:
ctx (context): Stores information for backward propagation.
            angles (torch.Tensor): A ((n - 1) * n // 2,)-shaped tensor containing all rotation
                angles between pairs of dimensions.
blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
pairs of independent rotations.
n (int): Dimension of the space.

Returns:
            torch.Tensor: The n-by-n rotation matrix.
"""
# Blocks is of shape (n_blocks, n/2, 2) containing indices for angles
        # Within each block the Givens rotations commute, so we can apply them in parallel
U = torch.eye(n, device=angles.device, dtype=angles.dtype)
block_size = n // 2
idx_block = torch.arange(block_size, device=angles.device)
for b, block in enumerate(blocks):
# angles is of shape (n_angles,) containing all angles for contiguous blocks.
angles_in_block = angles[idx_block + b * blocks.size(1)] # shape (n/2,)
c = torch.cos(angles_in_block)
s = torch.sin(angles_in_block)
i_idx = block[:, 0]
j_idx = block[:, 1]
r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
U[:, i_idx] = r_i
U[:, j_idx] = r_j
ctx.save_for_backward(angles, blocks, U)
return U
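
    # For intuition (worked by hand from the loop above): with n = 2 there is a
    # single block [(0, 1)], and the column updates reduce to the familiar
    # U = [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]].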

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Computes the VJP needed for backward propagation.

Args:
ctx (context): Contains information for backward propagation.
grad_output (torch.Tensor): A tensor containing the partial derivatives for the loss
with respect to the output of the forward pass, i.e., dL/dU.

Returns:
torch.Tensor: The gradient of the loss with respect to the input angles.
"""
angles, blocks, Ufwd_saved = ctx.saved_tensors
Ufwd = Ufwd_saved.clone()
        M = grad_output.t().clone()  # dL/dU, shape (n, n); clone so grad_output is not mutated in place
n = M.size(1)
block_size = n // 2
A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
grad_theta = torch.zeros_like(angles, dtype=angles.dtype)
idx_block = torch.arange(block_size, device=angles.device)
for b, block in enumerate(blocks):
i_idx = block[:, 0]
j_idx = block[:, 1]
angles_in_block = angles[idx_block + b * block_size] # shape (n/2,)
c = torch.cos(angles_in_block)
s = torch.sin(angles_in_block)
r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
Ufwd[i_idx] = r_i
Ufwd[j_idx] = r_j
r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
M[:, i_idx] = r_i
M[:, j_idx] = r_j
A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
grad_theta[idx_block + b * block_size] = A.sum(dim=1)
return grad_theta, None, None
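
# Illustrative sketch of how the custom backward could be validated (assumes
# the import path introduced in this PR; gradcheck needs double precision):
#
#     import torch
#     from dwave.plugins.torch.nn.modules.orthogonal import (
#         RoundRobinGivens,
#         get_blocks_edges,
#     )
#
#     n = 6
#     blocks = torch.tensor(get_blocks_edges(n), dtype=torch.long)
#     angles = torch.randn(n * (n - 1) // 2, dtype=torch.double, requires_grad=True)
#     torch.autograd.gradcheck(lambda a: RoundRobinGivens.apply(a, blocks, n), (angles,))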


class GivensRotationLayer(nn.Module):
"""An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged in
a round-robin fashion.

Angles are arranged into blocks, where each block references rotations that can be applied in
parallel because these rotations commute.

Args:
n (int): Dimension of the input and output space.
bias (bool): If True, adds a learnable bias to the output. Default: True.
"""

@store_config
def __init__(self, n: int, bias: bool = True):
super().__init__()
self.n = n
self.n_angles = n * (n - 1) // 2
self.angles = nn.Parameter(torch.randn(self.n_angles))
blocks_edges = get_blocks_edges(n)
self.register_buffer(
"blocks",
torch.tensor(blocks_edges, dtype=torch.long),
)
if bias:
self.bias = nn.Parameter(torch.zeros(n))
else:
self.register_parameter("bias", None)

def _create_rotation_matrix(self) -> torch.Tensor:
"""Computes the Givens rotation matrix."""
return RoundRobinGivens.apply(self.angles, self.blocks, self.n)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the Givens rotation to the input tensor ``x``.

Args:
x (torch.Tensor): Input tensor of shape (..., n).

Returns:
torch.Tensor: Rotated tensor of shape (..., n).
"""
U = self._create_rotation_matrix()
rotated_x = einsum(x, U, "... i, o i -> ... o")
if self.bias is not None:
rotated_x = rotated_x + self.bias
return rotated_x
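
# Illustrative usage (a sketch, not part of the diff): the rotation matrix is
# orthogonal with determinant +1, so the layer preserves norms up to the bias:
#
#     layer = GivensRotationLayer(8)
#     U = layer._create_rotation_matrix()
#     assert torch.allclose(U @ U.T, torch.eye(8), atol=1e-5)
#     y = layer(torch.randn(3, 8))  # output shape (3, 8)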
1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
"dimod",
"dwave-system",
"dwave-hybrid",
"einops",
]

[project.readme]
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,6 +2,7 @@ torch==2.9.1
dimod==0.12.18
dwave-system==1.28.0
dwave-hybrid==0.6.13
einops==0.8.1

# Development requirements
reno==4.1.0
87 changes: 87 additions & 0 deletions tests/helper_models.py
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
from einops import einsum


class NaiveGivensRotationLayer(nn.Module):
"""Naive implementation of a Givens rotation layer.

Sequentially applies all Givens rotations to implement an orthogonal transformation in an order
provided by blocks, which are of shape (n_blocks, n/2, 2), and where usually each block contains
pairs of indices such that no index appears more than once in a block. However, this
implementation does not rely on that assumption, so that indeces can appear multiple times in a
block; however, all pairs of indices must appear exactly once in the entire blocks tensor.

Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): If True, adds a learnable bias to the output. Default: True.

Note:
        This layer defines an n-by-n SO(n) rotation matrix, so in_features must be equal to
        out_features.
"""

def __init__(self, in_features: int, out_features: int, bias: bool = True):
super().__init__()
assert in_features == out_features, (
"This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
"out_features."
)
self.n = in_features
self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.register_parameter("bias", None)

def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
"""Creates the rotation matrix from the Givens angles by applying the Givens rotations in
order and sequentially, as specified by blocks.

Args:
angles (torch.Tensor): Givens rotation angles.
blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to create
the blocks. Defaults to None.

Returns:
torch.Tensor: Rotation matrix.
"""
block_size = self.n // 2
if blocks is None:
# Create dummy blocks from triu indices:
triu_indices = torch.triu_indices(self.n, self.n, offset=1)
            blocks = triu_indices.t().reshape(-1, block_size, 2)  # .t() is non-contiguous, so .view would fail
U = torch.eye(self.n, dtype=angles.dtype)
for b, block in enumerate(blocks):
for k in range(block_size):
i = block[k, 0].item()
j = block[k, 1].item()
angle = angles[b * block_size + k]
c = torch.cos(angle)
s = torch.sin(angle)
Ge = torch.eye(self.n, dtype=angles.dtype)
Ge[i, i] = c
Ge[j, j] = c
Ge[i, j] = -s
Ge[j, i] = s
# Explicit Givens rotation
U = U @ Ge
return U

def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
"""Applies the Givens rotation to the input tensor ``x``.

Args:
x (torch.Tensor): Input tensor of shape (..., n).
blocks (torch.Tensor): Blocks specifying the order of rotations.

Returns:
torch.Tensor: Rotated tensor of shape (..., n).
"""
W = self._create_rotation_matrix(self.angles, blocks)
x = einsum(x, W, "... i, o i -> ... o")
if self.bias is not None:
x = x + self.bias
return x
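
# Illustrative equivalence check (a sketch, assuming the module path used in
# this PR): driven by the same angles and blocks, this naive reference and the
# parallel GivensRotationLayer should produce the same output:
#
#     import torch
#     from dwave.plugins.torch.nn.modules.orthogonal import GivensRotationLayer
#
#     n = 6
#     fast = GivensRotationLayer(n, bias=False)
#     naive = NaiveGivensRotationLayer(n, n, bias=False)
#     naive.angles.data.copy_(fast.angles.data)
#     x = torch.randn(4, n)
#     assert torch.allclose(naive(x, fast.blocks), fast(x), atol=1e-5)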