
[Feature] Support ViT Anti Oversmoothing

okotaku opened this issue · 3 comments

Motivation

Support ViT Anti Oversmoothing.

Paper: "Anti-Oversmoothing in Deep Vision Transformers via the Fourier Domain Analysis: From Theory to Practice" (official repo)

Since ImageNet training is difficult with the resources at hand, I have not been able to verify the accuracy on ImageNet. Also, I currently support only Swin and ViT, but other transformer models can be supported as well. Let me know if you need them.
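For context, the paper's AttnScale module splits a softmax attention map into its low-pass (DC) component and the high-frequency remainder, then amplifies the latter with a learnable weight ω (initialized to 0). The NumPy sketch below is only an illustration of my reading of the method, not this PR's actual PyTorch implementation; the function name and signature are mine.

```python
import numpy as np

def attn_scale(attn, omega=0.0):
    """Illustrative AttnScale: amplify the high-frequency part of a
    softmax attention map (each row sums to 1) by a learnable weight omega."""
    n = attn.shape[-1]
    dc = np.full_like(attn, 1.0 / n)  # low-pass (DC) component of each row
    hc = attn - dc                    # high-frequency remainder
    # Rows still sum to 1: each dc row sums to 1 and each hc row sums to 0.
    return dc + (1.0 + omega) * hc
```

With `omega = 0` the map is unchanged; a positive `omega` counteracts the low-pass filtering behaviour that the paper attributes to stacked self-attention layers.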

Result

CUB200, batch size 8 × 1 GPU

| model | top-1 acc |
| --- | --- |
| swin-l | 91.6638 |
| + feat scale | 91.81913 |
| + attn scale | 91.56024 |
| + feat and attn scale | 91.78461 |

Since I have not done ImageNet pretraining, these accuracy figures are for reference only. Starting from ImageNet training may improve accuracy further, but the effect is likely limited (on the order of +0.XX accuracy).
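For reference, FeatScale applies the analogous idea to token features rather than attention maps: the feature map is split into its DC band (the mean over tokens) and the high-frequency remainder, and the two bands are re-weighted with learnable scales. This is a hedged NumPy sketch under my reading of the paper; the names and the exact re-weighting form are assumptions, not this PR's code.

```python
import numpy as np

def feat_scale(x, lamb_dc=0.0, lamb_hc=0.0):
    """Illustrative FeatScale: re-weight the DC and high-frequency
    bands of a token feature map x of shape (num_tokens, channels)."""
    dc = x.mean(axis=0, keepdims=True)  # low-frequency (DC) band
    hc = x - dc                         # high-frequency band
    return (1.0 + lamb_dc) * dc + (1.0 + lamb_hc) * hc
```

With both scales at 0 the features pass through unchanged; amplifying only the high-frequency band leaves the per-channel mean over tokens intact.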

Use cases (Optional)

```python
# model settings (Swin-S with AttnScale enabled via attn_scale=True)
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='small', img_size=224,
        drop_path_rate=0.3, attn_scale=True),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
```
```python
# model settings (Swin-S with FeatScale enabled via feat_scale=True)
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='small', img_size=224,
        drop_path_rate=0.3, feat_scale=True),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
```

Checklist

Before PR:

  • [x] Pre-commit or other linting tools are used to fix the potential lint issues.
  • [x] Bug fixes are fully covered by unit tests, the case that causes the bug should be added in the unit tests.
  • [x] The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  • [x] The documentation has been modified accordingly, like docstring or example tutorials.
  • [x] Check accuracy with CUB.

After PR:

  • [x] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • [x] CLA has been signed and all committers have signed the CLA in this PR.

okotaku avatar Mar 16 '22 01:03 okotaku

Codecov Report

Merging #731 (986ef1d) into dev (c1534f9) will increase coverage by 0.09%. The diff coverage is 100.00%.

```
@@            Coverage Diff             @@
##              dev     #731      +/-   ##
==========================================
+ Coverage   84.98%   85.07%   +0.09%
==========================================
  Files         122      122
  Lines        7573     7620      +47
  Branches     1304     1312       +8
==========================================
+ Hits         6436     6483      +47
  Misses        945      945
  Partials      192      192
```
| Flag | Coverage Δ |
| --- | --- |
| unittests | 85.00% <100.00%> (+0.09%) :arrow_up: |


| Impacted Files | Coverage Δ |
| --- | --- |
| mmcls/models/backbones/deit.py | 97.82% <ø> (ø) |
| mmcls/models/backbones/swin_transformer.py | 92.50% <100.00%> (+0.52%) :arrow_up: |
| mmcls/models/backbones/vision_transformer.py | 96.75% <100.00%> (+0.37%) :arrow_up: |
| mmcls/models/utils/attention.py | 100.00% <100.00%> (ø) |


codecov[bot] avatar Mar 16 '22 01:03 codecov[bot]

I am a little unsure whether this PR should be added or not. The implementation is simple, but I don't want to complicate the codebase by adding a rarely used parameter. What do you think?

okotaku avatar Mar 20 '22 04:03 okotaku

> I am a little unsure whether this PR should be added or not. The implementation is simple, but I don't want to complicate the codebase by adding a rarely used parameter. What do you think?

It is a pretty good feature! We plan to support large-scale models in the next quarter, and this feature may be useful.

Ezra-Yu avatar Mar 21 '22 08:03 Ezra-Yu