 import torch.nn as nn
 
 from einops import rearrange
+from einops.layers.torch import Reduce
 
 def _make_divisible(v, divisor, min_value=None):
-
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
 
 
-def Conv_BN_ReLU(inp, oup, kernel, stride=1):
+def conv_bn_relu(inp, oup, kernel, stride=1):
     return nn.Sequential(
         nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
         nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
-        project_out = not (heads == 1 and dim_head == dim)
-
         self.heads = heads
         self.scale = dim_head ** -0.5
 
@@ -74,7 +72,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ) if project_out else nn.Identity()
+        )
 
     def forward(self, x):
         qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
                 PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
             ]))
+
     def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
@@ -136,23 +135,24 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
         )
 
     def forward(self, x):
+        out = self.conv(x)
+
         if self.identity:
-            return x + self.conv(x)
-        else:
-            return self.conv(x)
+            out = out + x
+        return out
 
 class MobileViTBlock(nn.Module):
     def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
         super().__init__()
         self.ph, self.pw = patch_size
 
-        self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
+        self.conv1 = conv_bn_relu(channel, channel, kernel_size)
         self.conv2 = conv_1x1_bn(channel, dim)
 
         self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
 
         self.conv3 = conv_1x1_bn(dim, channel)
-        self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
+        self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
 
     def forward(self, x):
         y = x.clone()
@@ -165,8 +165,7 @@ def forward(self, x):
         _, _, h, w = x.shape
         x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
         x = self.transformer(x)
-        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
-                      pw=self.pw)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
 
         # Fusion
         x = self.conv3(x)
@@ -176,54 +175,65 @@ def forward(self, x):
 
 
 class MobileViT(nn.Module):
-    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+    def __init__(
+        self,
+        image_size,
+        dims,
+        channels,
+        num_classes,
+        expansion = 4,
+        kernel_size = 3,
+        patch_size = (2, 2),
+        depths = (2, 4, 3)
+    ):
         super().__init__()
+        assert len(dims) == 3, 'dims must be a tuple of 3'
+        assert len(depths) == 3, 'depths must be a tuple of 3'
+
         ih, iw = image_size
         ph, pw = patch_size
         assert ih % ph == 0 and iw % pw == 0
 
-        L = [2, 4, 3]
-
-        self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
-
-        self.mv2 = nn.ModuleList([])
-        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
-        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
-        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
-        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
-
-        self.mvit = nn.ModuleList([])
-        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
-        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
-        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
-
-        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
-
-        self.pool = nn.AvgPool2d(ih // 32, 1)
-        self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+        init_dim, *_, last_dim = channels
+
+        self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
+
+        self.stem = nn.ModuleList([])
+        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
+        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+
+        self.trunk = nn.ModuleList([])
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[3], channels[4], 2, expansion),
+            MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[5], channels[6], 2, expansion),
+            MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[7], channels[8], 2, expansion),
+            MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
+        ]))
+
+        self.to_logits = nn.Sequential(
+            conv_1x1_bn(channels[-2], last_dim),
+            Reduce('b c h w -> b c', 'mean'),
+            nn.Linear(channels[-1], num_classes, bias=False)
+        )
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.mv2[0](x)
-
-        x = self.mv2[1](x)
-        x = self.mv2[2](x)
-        x = self.mv2[3](x)
 
-        x = self.mv2[4](x)
-        x = self.mvit[0](x)
+        for conv in self.stem:
+            x = conv(x)
 
-        x = self.mv2[5](x)
-        x = self.mvit[1](x)
-
-        x = self.mv2[6](x)
-        x = self.mvit[2](x)
-        x = self.conv2(x)
-
-        x = self.pool(x).view(-1, x.shape[1])
-        x = self.fc(x)
-        return x
+        for conv, attn in self.trunk:
+            x = conv(x)
+            x = attn(x)
 
+        return self.to_logits(x)
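
For context, a minimal usage sketch of the refactored MobileViT constructor follows; it is not part of the diff. The image_size, dims, channels, and num_classes values are illustrative assumptions (roughly an XS-sized configuration), not values taken from this commit; the new code only requires len(dims) == 3, len(depths) == 3, and an image size divisible by the patch size.

# Usage sketch, assuming MobileViT from the file above is importable
# and that the config values below are plausible placeholders.
import torch

model = MobileViT(
    image_size = (256, 256),
    dims = (96, 120, 144),        # one transformer dim per MobileViT stage (must be 3)
    channels = (16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384),  # channel plan read by stem/trunk/head
    num_classes = 1000,
    depths = (2, 4, 3)            # transformer depth per stage (the new default)
)

img = torch.randn(1, 3, 256, 256)
logits = model(img)               # pooled by Reduce('b c h w -> b c'), then Linear -> (1, 1000)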