Skip to content

Commit

Permalink
docs for equiadapt/images/canonicalization_networks
Browse files Browse the repository at this point in the history
  • Loading branch information
sibasmarak committed Mar 13, 2024
1 parent 666ecd9 commit 5378926
Show file tree
Hide file tree
Showing 4 changed files with 387 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@


class CustomEquivariantNetwork(nn.Module):
"""
This class represents a custom equivariant network.
The network is equivariant to a specified group, which can be either the rotation group or the roto-reflection group. The network consists of a sequence of equivariant convolutional layers, each followed by a ReLU activation function.
Methods:
__init__: Initializes the CustomEquivariantNetwork instance.
forward: Performs a forward pass through the network.
"""

def __init__(
self,
in_shape: Tuple[int, int, int, int],
Expand All @@ -22,6 +32,18 @@ def __init__(
num_layers: int = 1,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initializes the CustomEquivariantNetwork instance.
Args:
in_shape (Tuple[int, int, int, int]): The shape of the input data.
out_channels (int): The number of output channels.
kernel_size (int): The size of the kernel in the convolutional layers.
group_type (str, optional): The type of group the network is equivariant to. Defaults to "rotation".
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
num_layers (int, optional): The number of layers in the network. Defaults to 1.
device (str, optional): The device to run the network on. Defaults to "cuda" if available, otherwise "cpu".
"""
super().__init__()

if group_type == "rotation":
Expand Down Expand Up @@ -57,8 +79,13 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x shape: (batch_size, in_channels, height, width)
:return: (batch_size, group_size)
Performs a forward pass through the network.
Args:
x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, height, width).
Returns:
torch.Tensor: The output of the network. It has the shape (batch_size, group_size).
"""
feature_map = self.eqv_network(x)
group_activatiobs = torch.mean(feature_map, dim=(1, 3, 4))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@


class RotationEquivariantConvLift(nn.Module):
"""
This class represents a rotation equivariant convolutional layer with lifting.
The layer is equivariant to a group of rotations. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
Methods:
__init__: Initializes the RotationEquivariantConvLift instance.
get_rotated_weights: Returns the weights of the layer after rotation.
forward: Performs a forward pass through the layer.
"""

def __init__(
self,
in_channels: int,
Expand All @@ -18,6 +29,19 @@ def __init__(
bias: bool = True,
device: str = "cuda",
):
"""
Initializes the RotationEquivariantConvLift instance.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The size of the kernel.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
stride (int, optional): The stride of the convolution. Defaults to 1.
padding (int, optional): The padding of the convolution. Defaults to 0.
bias (bool, optional): Whether to include a bias term. Defaults to True.
device (str, optional): The device to run the layer on. Defaults to "cuda".
"""
super().__init__()
self.weights = nn.Parameter(
torch.empty(out_channels, in_channels, kernel_size, kernel_size).to(device)
Expand All @@ -38,6 +62,16 @@ def __init__(
def get_rotated_weights(
self, weights: torch.Tensor, num_rotations: int = 4
) -> torch.Tensor:
"""
Returns the weights of the layer after rotation.
Args:
weights (torch.Tensor): The weights of the layer.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
Returns:
torch.Tensor: The weights after rotation.
"""
device = weights.device
weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1)
rotated_weights = K.geometry.rotate(
Expand All @@ -57,8 +91,13 @@ def get_rotated_weights(

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x shape: (batch_size, in_channels, height, width)
:return: (batch_size, out_channels, num_rotations, height, width)
Performs a forward pass through the layer.
Args:
x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, height, width).
Returns:
torch.Tensor: The output of the layer. It has the shape (batch_size, out_channels, num_rotations, height, width).
"""
batch_size = x.shape[0]
rotated_weights = self.get_rotated_weights(self.weights, self.num_rotations)
Expand All @@ -73,6 +112,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class RotoReflectionEquivariantConvLift(nn.Module):
"""
This class represents a roto-reflection equivariant convolutional layer with lifting.
The layer is equivariant to a group of rotations and reflections. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
Methods:
__init__: Initializes the RotoReflectionEquivariantConvLift instance.
get_rotoreflected_weights: Returns the weights of the layer after rotation and reflection.
forward: Performs a forward pass through the layer.
"""

def __init__(
self,
in_channels: int,
Expand All @@ -84,6 +134,19 @@ def __init__(
bias: bool = True,
device: str = "cuda",
):
"""
Initializes the RotoReflectionEquivariantConvLift instance.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The size of the kernel.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
stride (int, optional): The stride of the convolution. Defaults to 1.
padding (int, optional): The padding of the convolution. Defaults to 0.
bias (bool, optional): Whether to include a bias term. Defaults to True.
device (str, optional): The device to run the layer on. Defaults to "cuda".
"""
super().__init__()
num_group_elements = 2 * num_rotations
self.weights = nn.Parameter(
Expand All @@ -106,6 +169,16 @@ def __init__(
def get_rotoreflected_weights(
self, weights: torch.Tensor, num_rotations: int = 4
) -> torch.Tensor:
"""
Returns the weights of the layer after rotation and reflection.
Args:
weights (torch.Tensor): The weights of the layer.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
Returns:
torch.Tensor: The weights after rotation and reflection.
"""
device = weights.device
weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1)
rotated_weights = K.geometry.rotate(
Expand All @@ -127,8 +200,13 @@ def get_rotoreflected_weights(

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x shape: (batch_size, in_channels, height, width)
:return: (batch_size, out_channels, num_group_elements, height, width)
Performs a forward pass through the layer.
Args:
x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, height, width).
Returns:
torch.Tensor: The output of the layer. It has the shape (batch_size, out_channels, num_group_elements, height, width).
"""
batch_size = x.shape[0]
rotoreflected_weights = self.get_rotoreflected_weights(
Expand All @@ -149,6 +227,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class RotationEquivariantConv(nn.Module):
"""
This class represents a rotation equivariant convolutional layer.
The layer is equivariant to a group of rotations. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
Methods:
__init__: Initializes the RotationEquivariantConv instance.
get_rotated_permuted_weights: Returns the weights of the layer after rotation and permutation.
forward: Performs a forward pass through the layer.
"""

def __init__(
self,
in_channels: int,
Expand All @@ -160,6 +249,19 @@ def __init__(
bias: bool = True,
device: str = "cuda",
):
"""
Initializes the RotationEquivariantConv instance.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The size of the kernel.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
stride (int, optional): The stride of the convolution. Defaults to 1.
padding (int, optional): The padding of the convolution. Defaults to 0.
bias (bool, optional): Whether to include a bias term. Defaults to True.
device (str, optional): The device to run the layer on. Defaults to "cuda".
"""
super().__init__()
self.weights = nn.Parameter(
torch.empty(
Expand Down Expand Up @@ -196,6 +298,16 @@ def __init__(
def get_rotated_permuted_weights(
self, weights: torch.Tensor, num_rotations: int = 4
) -> torch.Tensor:
"""
Returns the weights of the layer after rotation and permutation.
Args:
weights (torch.Tensor): The weights of the layer.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
Returns:
torch.Tensor: The weights after rotation and permutation.
"""
weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1, 1)
permuted_weights = torch.gather(weights, 2, self.permute_indices_along_group)
rotated_permuted_weights = K.geometry.rotate(
Expand Down Expand Up @@ -223,8 +335,13 @@ def get_rotated_permuted_weights(

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x shape: (batch_size, in_channels, num_rotations, height, width)
:return: (batch_size, out_channels, num_rotations, height, width)
Performs a forward pass through the layer.
Args:
x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, num_rotations, height, width).
Returns:
torch.Tensor: The output of the layer. It has the shape (batch_size, out_channels, num_rotations, height, width).
"""
batch_size = x.shape[0]
x = x.flatten(1, 2)
Expand All @@ -245,6 +362,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class RotoReflectionEquivariantConv(nn.Module):
"""
This class represents a roto-reflection equivariant convolutional layer.
The layer is equivariant to a group of rotations and reflections. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
Methods:
__init__: Initializes the RotoReflectionEquivariantConv instance.
get_rotoreflected_permuted_weights: Returns the weights of the layer after rotation, reflection, and permutation.
forward: Performs a forward pass through the layer.
"""

def __init__(
self,
in_channels: int,
Expand All @@ -256,6 +384,19 @@ def __init__(
bias: bool = True,
device: str = "cuda",
):
"""
Initializes the RotoReflectionEquivariantConv instance.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The size of the kernel.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
stride (int, optional): The stride of the convolution. Defaults to 1.
padding (int, optional): The padding of the convolution. Defaults to 0.
bias (bool, optional): Whether to include a bias term. Defaults to True.
device (str, optional): The device to run the layer on. Defaults to "cuda".
"""
super().__init__()
num_group_elements: int = 2 * num_rotations
self.weights = nn.Parameter(
Expand Down Expand Up @@ -320,6 +461,16 @@ def __init__(
def get_rotoreflected_permuted_weights(
self, weights: torch.Tensor, num_rotations: int = 4
) -> torch.Tensor:
"""
Returns the weights of the layer after rotation, reflection, and permutation.
Args:
weights (torch.Tensor): The weights of the layer.
num_rotations (int, optional): The number of rotations in the group. Defaults to 4.
Returns:
torch.Tensor: The weights after rotation, reflection, and permutation.
"""
weights = (
weights.flatten(0, 1)
.unsqueeze(0)
Expand Down Expand Up @@ -357,8 +508,13 @@ def get_rotoreflected_permuted_weights(

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x shape: (batch_size, in_channels, num_group_elements, height, width)
:return: (batch_size, out_channels, num_group_elements, height, width)
Performs a forward pass through the layer.
Args:
x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, num_group_elements, height, width).
Returns:
torch.Tensor: The output of the layer. It has the shape (batch_size, out_channels, num_group_elements, height, width).
"""
batch_size = x.shape[0]
x = x.flatten(1, 2)
Expand Down
Loading

0 comments on commit 5378926

Please sign in to comment.