Commit 86a7302

Merge pull request #181 from murufeng/main
Add MobileViT
2 parents f4b0b14 + 89d3a04 commit 86a7302

File tree

- README.md
- images/mbvit.png
- vit_pytorch/mobile_vit.py

3 files changed: +255 -0 lines changed

README.md

Lines changed: 26 additions & 0 deletions
@@ -19,6 +19,7 @@
 - [CrossFormer](#crossformer)
 - [RegionViT](#regionvit)
 - [NesT](#nest)
+- [MobileViT](#mobilevit)
 - [Masked Autoencoder](#masked-autoencoder)
 - [Simple Masked Image Modeling](#simple-masked-image-modeling)
 - [Masked Patch Prediction](#masked-patch-prediction)
@@ -549,6 +550,31 @@ img = torch.randn(1, 3, 224, 224)
 pred = nest(img) # (1, 1000)
 ```
 
+## MobileViT
+
+<img src="./images/mbvit.png" width="400px"></img>
+
+This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduces MobileViT, a light-weight, general-purpose vision transformer for mobile devices. MobileViT presents a different perspective on the global processing of information with transformers.
+
+You can use it with the following code (e.g. mobilevit_xs):
+
+```
+import torch
+from vit_pytorch.mobile_vit import MobileViT
+
+mbvit_xs = MobileViT(
+    image_size = (256, 256),
+    dims = [96, 120, 144],
+    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
+    num_classes = 1000
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+pred = mbvit_xs(img) # (1, 1000)
+```
+
 ## Simple Masked Image Modeling
 
 <img src="./images/simmim.png" width="400px"/>
images/mbvit.png

206 KB

vit_pytorch/mobile_vit.py

Lines changed: 229 additions & 0 deletions

"""
An implementation of the MobileViT model, as defined in:

MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
arXiv: https://arxiv.org/abs/2110.02178

Original code: https://github.com/murufeng/awesome_lightweight_networks
"""

import torch
import torch.nn as nn

from einops import rearrange


def _make_divisible(v, divisor, min_value=None):
    # Round v to the nearest multiple of divisor, never going below min_value
    # and never dropping more than 10% below v. Kept from the original
    # MobileNet-style codebase; it is not used elsewhere in this file.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
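
# Worked example (hypothetical values): _make_divisible(33, 8) == 32 and
# _make_divisible(20, 8) == 24, i.e. channel widths snap to multiples of the divisor.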


def Conv_BN_ReLU(inp, oup, kernel, stride=1):
    # k x k convolution + batch norm + ReLU6; padding is fixed at 1,
    # so this helper is intended for 3x3 kernels
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    # pointwise (1x1) convolution + batch norm + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class PreNorm(nn.Module):
    # layer norm applied before the wrapped attention / feed-forward module
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffn(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, patch area, num patches, dim); attention runs over the
        # num-patches axis, independently for each position within a patch
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)
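
# Attention is shape-preserving, e.g. (hypothetical sizes) an input of
# (1, 4, 256, 96) - batch 1, 2x2 patch area 4, 256 patches, dim 96 -
# comes back as (1, 4, 256, 96).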


class Transformer(nn.Module):
    # a standard pre-norm ViT encoder, applied to the unfolded patch tensor
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2Block(nn.Module):
    # MobileNetV2 inverted residual block
    def __init__(self, inp, oup, stride=1, expand_ratio=4):
        super().__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)
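
# Shape example (hypothetical sizes): MV2Block(16, 24, stride=2) maps
# (1, 16, 128, 128) -> (1, 24, 64, 64); with stride=1 and inp == oup, a
# residual connection is added and the shape is unchanged.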


class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations: unfold into patches, run the transformer,
        # then fold back into a feature map
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

        # Fusion with the untouched copy of the input
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x
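
# MobileViTBlock is shape-preserving (hypothetical sizes): an input of
# (1, 64, 32, 32) with patch_size=(2, 2) returns (1, 64, 32, 32), provided
# the feature map's height and width are divisible by the patch size.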


class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        # number of transformer layers in each of the three MobileViT blocks
        L = [2, 4, 3]

        self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))  # repeat; assumes channels[2] == channels[3]
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        # the stem and the four stride-2 MV2 blocks downsample by 32 overall
        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        x = self.pool(x).view(-1, x.shape[1])
        x = self.fc(x)
        return x
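
To make the unfold/fold step in `MobileViTBlock` concrete, here is a minimal sketch (hypothetical sizes) of the `rearrange` round trip used around the transformer:

```
import torch
from einops import rearrange

# a tiny feature map: batch 1, dim 8, height 4, width 4, with 2x2 patches
x = torch.randn(1, 8, 4, 4)
ph = pw = 2

# unfold: each of the 4 positions within a patch attends across all 4 patches
t = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
print(t.shape)  # torch.Size([1, 4, 4, 8]) - (batch, patch area, num patches, dim)

# fold back to the original feature-map layout; the round trip is lossless
y = rearrange(t, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=2, w=2, ph=ph, pw=pw)
assert torch.equal(x, y)
```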
