-
Notifications
You must be signed in to change notification settings - Fork 2
/
convmixer.py
125 lines (93 loc) · 3.68 KB
/
convmixer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from functools import partial
import torch
import torch.nn as nn
from .ops import blocks
from .utils import export, config, load_from_local_or_url
from typing import Any
class Residual(nn.Sequential):
    """Skip connection wrapper: ``y = f(x) + x``.

    Subclasses ``nn.Sequential`` so the wrapped module(s) are registered
    as children and show up in ``state_dict`` under positional indices.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, x):
        # Run every wrapped module in order, then add the identity shortcut.
        # (The original applied only self[0]; since __init__ accepts *args,
        # extra modules were silently ignored. With a single wrapped module —
        # the only usage in this file — behavior is identical.)
        return super().forward(x) + x
@export
class ConvMixer(nn.Module):
    """ConvMixer backbone ("Patches Are All You Need?").

    A patch-embedding convolution followed by ``depth`` mixing stages,
    each pairing a residual depthwise convolution (spatial mixing) with
    a 1x1 convolution (channel mixing), then global average pooling and
    a linear classifier.
    """

    @blocks.normalizer(position='after')
    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 1000,
        h=None,
        depth=None,
        kernel_size: int = 9,
        patch_size: int = 7,
        **kwargs: Any
    ):
        super().__init__()

        # Patch embedding: non-overlapping patch_size x patch_size conv.
        stem = blocks.Conv2dBlock(in_channels, h, patch_size, stride=patch_size)

        stages = []
        for _ in range(depth):
            # Depthwise conv (groups=h) mixes spatially per channel;
            # the 1x1 conv then mixes across channels.
            spatial_mix = Residual(
                blocks.Conv2dBlock(h, h, kernel_size, groups=h, padding='same')
            )
            channel_mix = blocks.Conv2d1x1Block(h, h)
            stages.append(nn.Sequential(spatial_mix, channel_mix))

        self.features = nn.Sequential(stem, *stages)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(h, num_classes)

    def forward(self, x):
        """Extract features, global-average-pool, and classify."""
        feats = self.features(x)
        pooled = torch.flatten(self.pool(feats), 1)
        return self.classifier(pooled)
def _conv_mixer(
    h,
    depth,
    kernel_size: int = 9,
    patch_size: int = 7,
    pretrained: bool = False,
    pth: str = None,
    progress: bool = True,
    **kwargs: Any
):
    """Construct a ConvMixer and optionally load pretrained weights.

    ``pth`` takes a local checkpoint path; otherwise ``kwargs['url']``
    (if present) is used as a download source.
    """
    net = ConvMixer(
        h=h,
        depth=depth,
        kernel_size=kernel_size,
        patch_size=patch_size,
        **kwargs
    )

    if pretrained:
        load_from_local_or_url(net, pth, kwargs.get('url', None), progress)

    return net
@export
@blocks.activation(nn.GELU)
def conv_mixer_1536_20_k9_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-1536/20, 9x9 depthwise kernel, 7x7 patches, GELU."""
    return _conv_mixer(h=1536, depth=20, kernel_size=9, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(nn.GELU)
def conv_mixer_1536_20_k3_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-1536/20, 3x3 depthwise kernel, 7x7 patches, GELU."""
    return _conv_mixer(h=1536, depth=20, kernel_size=3, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(nn.GELU)
def conv_mixer_1024_20_k9_p14(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-1024/20, 9x9 depthwise kernel, 14x14 patches, GELU."""
    return _conv_mixer(h=1024, depth=20, kernel_size=9, patch_size=14,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(nn.GELU)
def conv_mixer_1024_16_k9_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-1024/16, 9x9 depthwise kernel, 7x7 patches, GELU."""
    return _conv_mixer(h=1024, depth=16, kernel_size=9, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(nn.GELU)
def conv_mixer_1024_12_k8_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-1024/12, 8x8 depthwise kernel, 7x7 patches, GELU."""
    return _conv_mixer(h=1024, depth=12, kernel_size=8, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(partial(nn.ReLU, inplace=True))
def conv_mixer_768_32_k7_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-768/32, 7x7 depthwise kernel, 7x7 patches, ReLU."""
    return _conv_mixer(h=768, depth=32, kernel_size=7, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(partial(nn.ReLU, inplace=True))
def conv_mixer_768_32_k3_p14(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-768/32, 3x3 depthwise kernel, 14x14 patches, ReLU."""
    return _conv_mixer(h=768, depth=32, kernel_size=3, patch_size=14,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(nn.GELU)
def conv_mixer_512_16_k8_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-512/16, 8x8 depthwise kernel, 7x7 patches, GELU."""
    return _conv_mixer(h=512, depth=16, kernel_size=8, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)
@export
@blocks.activation(nn.GELU)
def conv_mixer_512_12_k8_p7(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """ConvMixer-512/12, 8x8 depthwise kernel, 7x7 patches, GELU."""
    return _conv_mixer(h=512, depth=12, kernel_size=8, patch_size=7,
                       pretrained=pretrained, pth=pth, progress=progress, **kwargs)