-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
81 lines (70 loc) · 2.76 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from groups import GroupBase
class GConv3D(torch.nn.Module):
    """Group Equivariant Convolution Layer built on a 3D convolution.

    A single learnable kernel is stored. On every forward pass it is expanded
    to all group transformations (supplied by the ``group`` object) and the
    transformed copies are stacked along the output-channel axis, so one
    conv3d / conv_transpose3d call computes the whole group convolution.
    """

    def __init__(self, group : GroupBase, in_group_dim, in_channels, out_channels, kernel_size, transposed=False, stride=1, padding=0):
        """
        Group Equivariant Convolution Layer with 3D convolution.

        Args:
            group (GroupBase): Group object supplying the kernel transforms.
            in_group_dim (int): Size of input group dimension
                (1 for a lifting layer, or ``group.group_dim``).
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int or tuple): Size of the 3D convolutional kernel.
            transposed (bool): If True, apply ``conv_transpose3d`` instead of
                ``conv3d`` in the forward pass.
            stride (int or tuple): Convolution stride.
            padding (int or tuple): Padding size.

        Raises:
            ValueError: If ``in_group_dim`` is neither 1 nor ``group.group_dim``.
        """
        super().__init__()
        # Validate with a real exception: ``assert`` is stripped under
        # ``python -O``, and prepare_filters() already raises ValueError
        # for the same condition, so stay consistent with it.
        if in_group_dim != 1 and in_group_dim != group.group_dim:
            raise ValueError("in_group_dim must be 1 or group_dim")
        self.group = group
        self.in_group_dim = in_group_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.transposed = transposed
        self.stride = stride
        self.padding = padding

        def _to_tuple(value):
            # Broadcast a scalar kernel size to the three spatial dims.
            if isinstance(value, int):
                return (value, value, value)
            return value

        # One learnable kernel for the identity transformation; the other
        # group copies are generated on the fly in prepare_filters().
        # Note: torch.nn.Parameter sets requires_grad=True by default, so
        # no requires_grad argument is needed on torch.empty.
        self.kernel = torch.nn.Parameter(
            torch.empty(out_channels, in_group_dim, in_channels, *_to_tuple(kernel_size))
        )
        # initialize
        torch.nn.init.xavier_normal_(self.kernel)

    def prepare_filters(self):
        """
        Apply the group transformation to the filter.

        kernel (torch.Tensor): shape [out_channels,in_group_dim,in_channels,N0,N1,N2].

        Returns:
            torch.Tensor: Transformed kernel. shape [out_group*out_channels,in_group*in_channels,N0,N1,N2].

        Raises:
            ValueError: If ``self.in_group_dim`` is neither 1 nor ``group.group_dim``.
        """
        if self.in_group_dim == 1:
            # Lifting layer: input has no group axis yet.
            WN = self.group.get_Grotations(self.kernel)
        elif self.in_group_dim == self.group.group_dim:
            # Group layer: transform must also permute the input group axis.
            WN = self.group.get_Grotations_permutations(self.kernel)
        else:
            raise ValueError("in_group_dim must be 1 or group_dim")
        # WN is a list of [out_channels,in_group_dim,in_channels,N0,N1,N2]
        # tensors, one per group element; stack them along the output axis,
        # then fold (in_group_dim, in_channels) into a single conv channel dim.
        WN = torch.cat(WN, dim=0)
        WN = torch.flatten(WN, start_dim=1, end_dim=2)
        return WN

    def forward(self, x) -> torch.Tensor:
        """
        Forward pass of the GConv layer.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_group_dim,in_channels, N0,N1,N2].

        Returns:
            torch.Tensor: Output tensor after group equivariant convolution.
            [batch_size, out_group,out_channels, N0,N1,N2].
        """
        batch_size = x.size(0)
        # Fold (in_group_dim, in_channels) into one channel axis to match
        # the flattened filter layout from prepare_filters().
        xN = torch.flatten(x, start_dim=1, end_dim=2)
        WN = self.prepare_filters()
        if self.transposed:
            # conv_transpose3d expects weight as [in_ch, out_ch, k, k, k],
            # hence the transpose of the first two axes.
            yN = torch.nn.functional.conv_transpose3d(xN, WN.transpose(0, 1), stride=self.stride, padding=self.padding)
        else:
            yN = torch.nn.functional.conv3d(xN, WN, stride=self.stride, padding=self.padding)
        # Unfold the channel axis back into (group, out_channels).
        y = yN.reshape(batch_size, self.group.group_dim, self.out_channels, *yN.size()[2:])
        return y