From ee9af93349b9f8ca714b14d9b4cdb8e28bd0be22 Mon Sep 17 00:00:00 2001 From: Ze Liu Date: Sun, 4 Jul 2021 10:18:27 +0800 Subject: [PATCH] Add Swin MLP: a hierarchical fully MLP architecture using shifted windows. (#91) --- README.md | 32 +- config.py | 12 + configs/swin_mlp_base_patch4_window7_224.yaml | 9 + .../swin_mlp_tiny_c12_patch4_window8_256.yaml | 11 + .../swin_mlp_tiny_c24_patch4_window8_256.yaml | 11 + .../swin_mlp_tiny_c6_patch4_window8_256.yaml | 11 + configs/swin_tiny_c24_patch4_window8_256.yaml | 11 + models/build.py | 16 + models/swin_mlp.py | 468 ++++++++++++++++++ 9 files changed, 573 insertions(+), 8 deletions(-) create mode 100644 configs/swin_mlp_base_patch4_window7_224.yaml create mode 100644 configs/swin_mlp_tiny_c12_patch4_window8_256.yaml create mode 100644 configs/swin_mlp_tiny_c24_patch4_window8_256.yaml create mode 100644 configs/swin_mlp_tiny_c6_patch4_window8_256.yaml create mode 100644 configs/swin_tiny_c24_patch4_window8_256.yaml create mode 100644 models/swin_mlp.py diff --git a/README.md b/README.md index 9c44865a..5d05d3c1 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,9 @@ This repo is the official implementation of ["Swin Transformer: Hierarchical Vis > **Video Swin Transformer**: See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). ## Updates +***07/03/2021*** +1. Add **Swin MLP**: a hierarchical fully MLP architecture using shifted windows. + ***06/25/2021*** 1. [Video Swin Transformer](https://arxiv.org/abs/2106.13230) is released at [Video-Swin-Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). `Video Swin Transformer` achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (`84.9` top-1 accuracy on Kinetics-400 and `86.1` top-1 accuracy on Kinetics-600 with `~20x` less pre-training data and `~3x` smaller model size) and temporal modeling (`69.6` top-1 accuracy on Something-Something v2). @@ -61,14 +64,27 @@ ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: | -| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) | -| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) | -| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) | -| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) | -| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) | -| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) | -| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) | -| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) | +| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin_tiny_patch4_window7_224.yaml) | +| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin_small_patch4_window7_224.yaml) | +| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin_base_patch4_window7_224.yaml) | +| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[test-config](configs/swin_base_patch4_window12_384.yaml) | +| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[test-config](configs/swin_base_patch4_window7_224.yaml) | +| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[test-config](configs/swin_base_patch4_window12_384.yaml) | +| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[test-config](configs/swin_large_patch4_window7_224.yaml) | +| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[test-config](configs/swin_large_patch4_window12_384.yaml) | + +**ImageNet-1K Pretrained Swin MLP Models** + +| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS | 1K model | +| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | +| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) | +| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) | +| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G | 231 | [timm](https://github.com/rwightman/pytorch-image-models) | +| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin_tiny_c24_patch4_window8_256.yaml) | +| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin_mlp_tiny_c24_patch4_window8_256.yaml) | +| SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin_mlp_tiny_c12_patch4_window8_256.yaml) | +| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin_mlp_tiny_c6_patch4_window8_256.yaml) | +| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin_mlp_base_patch4_window7_224.yaml) | Note: access code for `baidu` is `swin`. diff --git a/config.py b/config.py index 5f150f84..d9c85c31 100644 --- a/config.py +++ b/config.py @@ -71,6 +71,18 @@ _C.MODEL.SWIN.APE = False _C.MODEL.SWIN.PATCH_NORM = True +# Swin MLP parameters +_C.MODEL.SWIN_MLP = CN() +_C.MODEL.SWIN_MLP.PATCH_SIZE = 4 +_C.MODEL.SWIN_MLP.IN_CHANS = 3 +_C.MODEL.SWIN_MLP.EMBED_DIM = 96 +_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] +_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] +_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 +_C.MODEL.SWIN_MLP.MLP_RATIO = 4. +_C.MODEL.SWIN_MLP.APE = False +_C.MODEL.SWIN_MLP.PATCH_NORM = True + # ----------------------------------------------------------------------------- # Training settings # ----------------------------------------------------------------------------- diff --git a/configs/swin_mlp_base_patch4_window7_224.yaml b/configs/swin_mlp_base_patch4_window7_224.yaml new file mode 100644 index 00000000..01c48c95 --- /dev/null +++ b/configs/swin_mlp_base_patch4_window7_224.yaml @@ -0,0 +1,9 @@ +MODEL: + TYPE: swin_mlp + NAME: swin_mlp_base_patch4_window7_224 + DROP_PATH_RATE: 0.5 + SWIN_MLP: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 7 diff --git a/configs/swin_mlp_tiny_c12_patch4_window8_256.yaml b/configs/swin_mlp_tiny_c12_patch4_window8_256.yaml new file mode 100644 index 00000000..d6c4576d --- /dev/null +++ b/configs/swin_mlp_tiny_c12_patch4_window8_256.yaml @@ -0,0 +1,11 @@ +DATA: + IMG_SIZE: 256 +MODEL: + TYPE: swin_mlp + NAME: swin_mlp_tiny_c12_patch4_window8_256 + DROP_PATH_RATE: 0.2 + SWIN_MLP: + EMBED_DIM: 96 + DEPTHS: [ 2, 2, 6, 2 ] + NUM_HEADS: [ 8, 16, 32, 64 ] + WINDOW_SIZE: 8 \ No newline at end of file diff --git a/configs/swin_mlp_tiny_c24_patch4_window8_256.yaml b/configs/swin_mlp_tiny_c24_patch4_window8_256.yaml new file mode 100644 index 00000000..15552a0d --- /dev/null +++ b/configs/swin_mlp_tiny_c24_patch4_window8_256.yaml @@ -0,0 +1,11 @@ +DATA: + IMG_SIZE: 256 +MODEL: + TYPE: swin_mlp + NAME: swin_mlp_tiny_c24_patch4_window8_256 + DROP_PATH_RATE: 0.2 + SWIN_MLP: + EMBED_DIM: 96 + DEPTHS: [ 2, 2, 6, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 8 \ No newline at end of file diff --git a/configs/swin_mlp_tiny_c6_patch4_window8_256.yaml b/configs/swin_mlp_tiny_c6_patch4_window8_256.yaml new file mode 100644 index 00000000..533bd8f6 --- /dev/null +++ b/configs/swin_mlp_tiny_c6_patch4_window8_256.yaml @@ -0,0 +1,11 @@ +DATA: + IMG_SIZE: 256 +MODEL: + TYPE: swin_mlp + NAME: swin_mlp_tiny_c6_patch4_window8_256 + DROP_PATH_RATE: 0.2 + SWIN_MLP: + EMBED_DIM: 96 + DEPTHS: [ 2, 2, 6, 2 ] + NUM_HEADS: [ 16, 32, 64, 128 ] + WINDOW_SIZE: 8 \ No newline at end of file diff --git a/configs/swin_tiny_c24_patch4_window8_256.yaml b/configs/swin_tiny_c24_patch4_window8_256.yaml new file mode 100644 index 00000000..2c2e9f9d --- /dev/null +++ b/configs/swin_tiny_c24_patch4_window8_256.yaml @@ -0,0 +1,11 @@ +DATA: + IMG_SIZE: 256 +MODEL: + TYPE: swin + NAME: swin_tiny_c24_patch4_window8_256 + DROP_PATH_RATE: 0.2 + SWIN: + EMBED_DIM: 96 + DEPTHS: [ 2, 2, 6, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 8 \ No newline at end of file diff --git a/models/build.py b/models/build.py index 632a2459..8cefcf82 100644 --- a/models/build.py +++ b/models/build.py @@ -6,6 +6,7 @@ # -------------------------------------------------------- from .swin_transformer import SwinTransformer +from .swin_mlp import SwinMLP def build_model(config): @@ -27,6 +28,21 @@ def build_model(config): ape=config.MODEL.SWIN.APE, patch_norm=config.MODEL.SWIN.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT) + elif model_type == 'swin_mlp': + model = SwinMLP(img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, + in_chans=config.MODEL.SWIN_MLP.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, + depths=config.MODEL.SWIN_MLP.DEPTHS, + num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, + window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWIN_MLP.APE, + patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT) else: raise NotImplementedError(f"Unkown model: {model_type}") diff --git a/models/swin_mlp.py b/models/swin_mlp.py new file mode 100644 index 00000000..115c43cd --- /dev/null +++ b/models/swin_mlp.py @@ -0,0 +1,468 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SwinMLPBlock(nn.Module): + r""" Swin MLP Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.padding = [self.window_size - self.shift_size, self.shift_size, + self.window_size - self.shift_size, self.shift_size] # P_l,P_r,P_t,P_b + + self.norm1 = norm_layer(dim) + # use group convolution to implement multi-head MLP + self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2, + self.num_heads * self.window_size ** 2, + kernel_size=1, + groups=self.num_heads) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # shift + if self.shift_size > 0: + P_l, P_r, P_t, P_b = self.padding + shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0) + else: + shifted_x = x + _, _H, _W, _ = shifted_x.shape + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # Window/Shifted-Window Spatial MLP + x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads) + x_windows_heads = x_windows_heads.transpose(1, 2) # nW*B, nH, window_size*window_size, C//nH + x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size, + C // self.num_heads) + spatial_mlp_windows = self.spatial_mlp(x_windows_heads) # nW*B, nH*window_size*window_size, C//nH + spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size, + C // self.num_heads).transpose(1, 2) + spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C) + + # merge windows + spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W) # B H' W' C + + # reverse shift + if self.shift_size > 0: + P_l, P_r, P_t, P_b = self.padding + x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous() + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + + # Window/Shifted-Window Spatial MLP + if self.shift_size > 0: + nW = (H / self.window_size + 1) * (W / self.window_size + 1) + else: + nW = H * W / self.window_size / self.window_size + flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin MLP layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., drop_path=0., + norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinMLPBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinMLP(nn.Module): + r""" Swin MLP + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin MLP layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + drop_rate (float): Dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, nn.Conv1d)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops