Current Status of Experiment for Local Self-Attention employed in MetaFormer
Introduction and Related Work
Some words of history of deep learning-based medical image processing
Era of convolution
New era: introducing self-attention to computer vision, and its limitations (missing local inductive bias, O(n²) complexity; SegFormer's dim-reduction trick, etc.)
Best of both worlds via local self-attention?
Made attention more efficient by introducing locality → see FLOPs plots
Related work (to be completed): SwinUnet, UXNet, Slide-Transformer, MedNeXt, ResNeXt, PVT, SegFormer, Spatial MLP ...
Alternative approach to reduce complexity: MAMBA (separates dimensions)
Some sentences about papers utilizing large receptive fields (DeepLab, PyDiNet) and their benefits in some tasks
Motivated by MetaFormer findings regarding token mixing in self-attention for computer vision
MetaFormer and especially PoolFormer demonstrated that attention itself is not important for good performance but token mixing (exchange of information) is
Architecture of MetaFormer enables the use of different token mixer
PoolFormer proves that some locality in feature aggregation is beneficial
Hybrid stages in MetaFormer have achieved the best Acc@1 on ImageNet within their experiments
→ learnable token mixer still seems to be beneficial
PoolFormer [pool, pool, pool, pool]: 77.2%
MetaFormer [pool, pool, attention, attention]: 81%
MetaFormer with depthwise convolutions: 78.1%
Motivation to employ local self-attention as token mixer
Contributions
Local self-attention as token mixer in MetaFormer
Comparison to other token mixers like pooling, attention, and depthwise convolutions
Investigation of possible extensions of local self-attention in MetaFormer
slopes (bias for softmax based on distance of query and key tokens)
position encoding (2 layer MLP on Pytorch Coordinate Grid at each stage)
global register to counter artifacts (see VISION TRANSFORMERS NEED REGISTERS)
use of CLS token (global anchor capturing global context)
Investigate the importance of pretraining for local self-attention to catch up with the inductive bias of convolutions
Identify and (partially) prevent training instabilities using a local FlexAttention implementation
use classification as example for global prediction task. Five datasets: ImageWoof, 3 non-trivial MedMNIST, Graz fracture classification (TBA?)
use semantic segmentation as example for dense prediction task. Two 2D datasets: Graz bone segmentation, JSRT (add one more?)
employ local self-attention as conv replacement in nnUNet for 3d segmentation and demonstrate its performance on two datasets of the Medical Decathlon
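The locality constraint behind the local self-attention above can be sketched as a flex_attention-style `mask_mod` predicate. This is an assumed illustration, not the actual implementation: tokens are taken to be flattened row-major from an H×W grid, and the kernel size matches the 5² setting used in the experiments below.

```python
# Assumed sketch: locality predicate for local self-attention as a
# flex_attention-style mask_mod. Grid shape and kernel size are illustrative.

K = 5          # local window (kernel) size, as in the experiments
H, W = 14, 14  # example feature-map resolution

def local_mask_mod(b, h, q_idx, kv_idx):
    """True where key kv_idx lies inside the K x K window around query q_idx."""
    qy, qx = q_idx // W, q_idx % W
    ky, kx = kv_idx // W, kv_idx % W
    r = K // 2
    dy, dx = qy - ky, qx - kx
    # bitwise & so the same predicate works on Python ints and torch tensors
    return (dy >= -r) & (dy <= r) & (dx >= -r) & (dx <= r)

# With PyTorch >= 2.5 this predicate could be compiled into a BlockMask:
#   from torch.nn.attention.flex_attention import create_block_mask, flex_attention
#   mask = create_block_mask(local_mask_mod, B=None, H=None, Q_LEN=H*W, KV_LEN=H*W)
#   out = flex_attention(q, k, v, block_mask=mask)
```

Note that the predicate does not wrap around grid rows: a token at the end of one row is not a neighbor of the first token of the next row.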
Experiments
Impact of Different Extensions
Investigate the importance of MetaFormer's patch embedding (employed at the beginning of each stage) and try alternatives
local self-attention (kernel size 5) as token mixer
trained from scratch on ImageWoof
† means NaN during backpropagation (training instability), max value before NaN is reported
| Extension | Acc@1 | AUC | F1 |
| --- | --- | --- | --- |
| Conv (default) | 0.6529 | 0.892 | 0.6539 |
| none | 0.3413† | – | – |
| Coord-MLP | 0.2631† | – | – |
| Coord-MLP + Conv | 0.6442 | 0.8895 | 0.6456 |
Conv (default): the default patch embedding of MetaFormer, considered as the baseline.
none: no patch embedding; replaced by bilinear interpolation and a 1×1 conv (no norm or activation) at each stage.
Coord-MLP: a 2-layer MLP on the PyTorch coordinate grid at the first MetaFormer block of each stage.
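A minimal sketch of what the Coord-MLP variant could look like; the hidden width, normalization of the grid to [-1, 1], and additive combination with the features are my assumptions, not the actual code:

```python
import torch
import torch.nn as nn

class CoordMLP(nn.Module):
    """Assumed sketch of the Coord-MLP extension: a 2-layer MLP maps a
    normalized (x, y) coordinate grid to per-position embeddings that are
    added to the feature map at the first block of a stage."""

    def __init__(self, dim, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden), nn.GELU(), nn.Linear(hidden, dim)
        )

    def forward(self, x):  # x: (B, C, H, W)
        B, C, H, W = x.shape
        ys = torch.linspace(-1, 1, H, device=x.device)
        xs = torch.linspace(-1, 1, W, device=x.device)
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")  # coordinate grid
        coords = torch.stack([gx, gy], dim=-1)          # (H, W, 2)
        pos = self.mlp(coords).permute(2, 0, 1)         # (dim, H, W)
        return x + pos.unsqueeze(0)
```

Since the MLP only sees coordinates, its output is resolution-dependent but content-independent, i.e. a learned positional encoding per stage.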
Other extensions to local self-attention (using the default patch embedding and the same setting as above):
| Extension | Acc@1 | AUC | F1 |
| --- | --- | --- | --- |
| Baseline | 0.6529 | 0.892 | 0.6539 |
| Slopes | 0.6512 | 0.8942 | 0.6508 |
| CLS Token | 0.5472 | 0.8648 | 0.5455 |
| Global Register | 0.6532 | 0.8766 | 0.6533 |
Slopes: using score_mod to modify the attention score before softmax, adding a bias (weighted by 0.1) based on the distance between query and key tokens.
Slopes are calculated for the x and y directions separately; each head has its own slope (bias along ±x and ±y).
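The slopes idea can be sketched as a flex_attention-style `score_mod` (ALiBi-like). The concrete slope values and the exact form of the distance bias below are illustrative assumptions, not the actual implementation:

```python
# Assumed sketch of the "slopes" extension as a flex_attention-style score_mod:
# a per-head slope scales a bias derived from the x/y distances between query
# and key tokens; the bias is weighted by 0.1 and added before softmax.

W = 14                                        # feature-map width, row-major tokens
WEIGHT = 0.1                                  # bias weight from the notes
SLOPES = [2.0 ** -(i + 1) for i in range(8)]  # one slope per head (assumed values)

def slopes_score_mod(score, b, h, q_idx, kv_idx):
    qy, qx = q_idx // W, q_idx % W
    ky, kx = kv_idx // W, kv_idx % W
    # x and y distances enter separately, scaled by the head's slope
    bias = SLOPES[h] * (abs(qx - kx) + abs(qy - ky))
    return score - WEIGHT * bias
```

Nearby keys are penalized less than distant ones, so each head gets a soft, learned-free locality prior on top of the hard local mask.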
CLS Token: could function as a global anchor capturing global context. An nn.Parameter with the dim/channels of the first patch embedding (64),
projected by a linear layer to the channel dim of each stage. The CLS token passes through the MetaFormer channel MLP and norm layers.
Final classification is done on the CLS token by the model head (the default is global average pooling over the very last feature map).
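A possible shape of this CLS-token handling; the stage channel dims (PoolFormer-S defaults) and the per-stage projection scheme are assumptions, not the actual code:

```python
import torch
import torch.nn as nn

class CLSTokenHead(nn.Module):
    """Assumed sketch of the CLS-token extension: a learnable token with the
    first embedding dim (64) is projected to each stage's channel dim, carried
    alongside the spatial tokens, and used for the final classification
    instead of global average pooling."""

    def __init__(self, base_dim=64, stage_dims=(64, 128, 320, 512), num_classes=10):
        super().__init__()
        self.cls = nn.Parameter(torch.zeros(1, 1, base_dim))
        self.proj = nn.ModuleList(nn.Linear(base_dim, d) for d in stage_dims)
        self.head = nn.Linear(stage_dims[-1], num_classes)

    def cls_for_stage(self, batch, stage):
        # project the shared CLS parameter to the stage's channel dim
        return self.proj[stage](self.cls).expand(batch, -1, -1)

    def classify(self, cls_final):  # cls_final: (B, 1, stage_dims[-1])
        return self.head(cls_final.squeeze(1))
```

Because the token is re-projected per stage rather than carried through the downsampling patch embeddings, it survives the resolution changes between stages.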
Global Register: Inspired by VISION TRANSFORMERS NEED REGISTERS. Employ CLS token as single global register,
but the classification is done over the pooled final feature map (default MetaFormer implementation).
Thoughts on Results:
For me, it seems that no extension is really beneficial. I would still like to mention them in the paper, since we have
a good initial motivation/hypothesis for each of them (not stated here). I will also include the difference in parameter count,
to estimate the cost of each extension, especially the global register.
Rerun with larger kernel (7)?
Impact of Pretraining
Try to use features of pretrained pooling stages to overcome the lack of inductive bias in local self-attention.
I sequentially increase the number of pretrained pooling stages, from 0 (no pretrained weights at all) to 4 (only pretrained stages).
In the first table, the pretrained weights are kept frozen. † again means NaN during backpropagation (training instability); the max value before the NaN is reported.
The last column is shared between both token mixers of a dataset: with all four stages pretrained as pooling, no stage uses the alternative token mixer.
| Dataset | TokenMixer | from scratch | stage 1 | stage 2 | stage 3 | linear probe |
| --- | --- | --- | --- | --- | --- | --- |
| ImageWoof | pooling | 0.77 | 0.796 | 0.857 | 0.9245 | 0.94 |
| ImageWoof | locattn | 0.74 | 0.654† | 0.854 | 0.922 | 0.94 |
| DermaMNIST | pooling | 0.7502 | 0.671 | 0.74 | 0.749 | 0.61 |
| DermaMNIST | locattn | 0.7323 | 0.376† | 0.78 | 0.75 | 0.61 |
Same setting, but now fine-tuning the pretrained weights of the pooling stages.
| Dataset | TokenMixer | from scratch | stage 1 | stage 2 | stage 3 | all fine-tune |
| --- | --- | --- | --- | --- | --- | --- |
| ImageWoof | pooling | 0.77 | 0.8026 | 0.8441 | 0.8766 | 0.8931 |
| ImageWoof | locattn | 0.74 | 0.4107† | 0.8000 | 0.8806 | 0.8931 |
| DermaMNIST | pooling | 0.7502 | 0.7475 | 0.79521 | 0.8411 | 0.8294 |
| DermaMNIST | locattn | 0.7323 | 0.3617† | 0.6617 | 0.8078 | 0.8294 |
Thoughts on Results:
I cannot really explain the performance of local self-attention on Derma when two pretrained pooling stages are frozen.
For ImageWoof, using as many frozen pretrained parameters as possible yields the best performance. No surprise here.
For Derma as an "unknown" domain, pooling always outperforms local self-attention (except in the case of two frozen stages).
Stage 3 is not really an option: at that resolution (7²) we effectively employ full self-attention, since there is no locality left with kernel sizes > 3.
Takeaway: just use pooling, since it adds no additional parameters and is efficient? Paper done :D
Different Token Mixers for pretrained MetaFormer
First two stages of MetaFormer are pretrained with pooling as token mixer and frozen. X² indicates the kernel size.
| TokenMixer | OrganS | Pneumonia | Derma | ImageWoof |
| --- | --- | --- | --- | --- |
| Pooling 3² | 0.8793 | 0.983 | 0.7416 | 0.8566 |
| Conv 3² | – | – | – | – |
| Conv 5² | – | – | – | 0.8443 |
| Depthwise Conv 5² | – | – | – | 0.8430 |
| PyDiConv | – | – | – | – |
| LocalAttention 5² | 0.8747 | 0.9766 | 0.7831 | 0.8611 |
| FullAttention | – | – | – | 0.8584 |
For comparison, the results of a CNN (ResNet34), both ImageNet-pretrained and trained from scratch, are included.

| TokenMixer | OrganS | Pneumonia | Derma | ImageWoof |
| --- | --- | --- | --- | --- |
| ResNet34 3² (pretrained) | – | – | – | 0.9085 |
| ResNet34 3² (scratch) | – | – | – | 0.7433 |
| ResNet34 5² (scratch) | – | – | – | 0.7165 |
Thoughts on Results:
Rerun with training from scratch to be comparable with ResNet34?
Missing: local self-attention with kernel sizes 3² and 7²!!
Include kernel size 7² in the ResNet34 experiments.
A larger kernel in ResNet does not seem to help...
MetaFormer from scratch
Results on ImageWoof, reported as Acc@1, AUC, F1.

| Token Mixer | Kernel Size | Acc@1 | AUC | F1 |
| --- | --- | --- | --- | --- |
| pooling | 3 | 0.7908 | 0.9495 | 0.7907 |
| pooling | 5 | 0.7941 | 0.9517 | 0.7965 |
| pooling | 7 | 0.8044 | 0.9539 | 0.8059 |
| conv | 3 | 0.7869 | 0.9483 | 0.7891 |
| conv | 5 | 0.754 | 0.9353 | 0.7567 |
| conv | 7 | 0.7078 | 0.9235 | 0.709 |
| sep_conv | 3 | 0.7915 | 0.9534 | 0.7937 |
| sep_conv | 5 | 0.7694 | 0.9412 | 0.7713 |
| sep_conv | 7 | 0.7411 | 0.9341 | 0.7424 |
| locAttn * | 3 | 0.7156 | 0.9066 | 0.7153 |
| locAttn * | 5 | 0.7037 | 0.8922 | 0.7037 |
| locAttn * | 7 | – | – | – |
| full_attn | – | – | – | – |
(*) reduced lr to 1e-4 but doubled epochs due to training instabilities
Pretrained MetaFormer
Use of a pretrained MetaFormer (PPAA, i.e. [pool, pool, attention, attention]). Reuse the attention weights for local self-attention.