mmpretrain

[Feature] MultiheadAttentionPooling

chagmgang opened this pull request 3 years ago • 5 comments


Motivation

This PR adds a multi-head attention pooling neck. The design is inspired by Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks.
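The core idea of attention pooling in Set Transformer can be sketched as follows: a learnable "seed" vector acts as the query and attends over all H×W spatial positions of a feature map, so the pooled vector does not depend on the order of those positions. This is a simplified PyTorch illustration, not the code added in this PR. The class name `AttentionPoolingSketch` is hypothetical, and the sketch omits the feed-forward and normalization layers of the paper's full pooling block.

```python
import torch
import torch.nn as nn


class AttentionPoolingSketch(nn.Module):
    """Hypothetical attention pooling: one learnable seed query attends over
    every spatial position of a feature map (illustration only)."""

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        # A single learnable seed vector serves as the pooling query
        self.seed = nn.Parameter(torch.zeros(1, 1, channels))
        nn.init.trunc_normal_(self.seed, std=0.02)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C, H, W) -> (N, H*W, C): each spatial position becomes a token
        tokens = x.flatten(2).transpose(1, 2)
        # Share the same seed across the batch: (N, 1, C)
        query = self.seed.expand(x.size(0), -1, -1)
        pooled, _ = self.attn(query, tokens, tokens)
        return pooled.squeeze(1)  # (N, C)
```

Because the output is a softmax-weighted sum over tokens, shuffling the spatial positions leaves it unchanged. The paper's pooling block allows k seeds; a single seed yields one vector per image, which is what a classification head expects.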

Modification

The multi-head attention pooling neck is implemented following the code style of the existing `hr_fuse` neck.
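One way such a neck could handle several backbone stages is sketched below. It is modeled only on what is visible in this thread (a tuple of per-stage feature maps in, a one-element tuple out, and `in_channels`, `num_heads` and `out_channels` arguments as in the configs posted later). The class name and the pool-each-stage, concatenate, then project design are assumptions, not the code in `mmcls/models/necks/map.py`.

```python
import torch
import torch.nn as nn


class MultiScaleAttentionPoolingSketch(nn.Module):
    """Hypothetical neck: attention-pool each stage, concatenate, project.
    Follows a tuple-in / 1-tuple-out convention; NOT this PR's code."""

    def __init__(self, in_channels, num_heads, out_channels):
        super().__init__()
        assert len(in_channels) == len(num_heads)
        # One learnable pooling query per stage
        self.seeds = nn.ParameterList(
            [nn.Parameter(torch.randn(1, 1, c) * 0.02) for c in in_channels])
        self.attns = nn.ModuleList(
            [nn.MultiheadAttention(c, h, batch_first=True)
             for c, h in zip(in_channels, num_heads)])
        # Map the concatenated per-stage vectors to out_channels
        self.proj = nn.Linear(sum(in_channels), out_channels)

    def forward(self, inputs):
        assert len(inputs) == len(self.attns)
        pooled = []
        for x, seed, attn in zip(inputs, self.seeds, self.attns):
            tokens = x.flatten(2).transpose(1, 2)  # (N, H*W, C)
            query = seed.expand(x.size(0), -1, -1)  # (N, 1, C)
            out, _ = attn(query, tokens, tokens)
            pooled.append(out.squeeze(1))  # (N, C)
        return (self.proj(torch.cat(pooled, dim=1)),)
```

Concatenating before a single linear layer is only one option; summing per-stage projections would be an equally plausible design.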

BC-breaking (Optional)

Use cases (Optional)

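Usage is covered by the unit tests. The snippet below, as posted, constructs the existing `HRFuseScales` neck, whose code style this PR follows; it shows the multi-scale input convention (one tensor per stage, with the spatial size halving at each stage).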

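import torch
from mmcls.models.necks import HRFuseScales

in_channels = (18, 32, 64, 128)
neck = HRFuseScales(in_channels=in_channels, out_channels=1024)

# Build a 4-level feature pyramid: spatial size halves at each level
feat_size = 56
inputs = []
for in_channel in in_channels:
    input_tensor = torch.rand(3, in_channel, feat_size, feat_size)
    inputs.append(input_tensor)
    feat_size = feat_size // 2

outs = neck(tuple(inputs))
# All levels are fused into one map at the coarsest resolution (56 / 8 = 7)
assert isinstance(outs, tuple)
assert len(outs) == 1
assert outs[0].shape == (3, 1024, 7, 7)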

Checklist

Before PR:

  • [x] Pre-commit or other linting tools are used to fix the potential lint issues.
  • [x] Bug fixes are fully covered by unit tests; the case that causes the bug should be added to the unit tests.
  • [x] The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • [x] The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • [x] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • [x] CLA has been signed and all committers have signed the CLA in this PR.

chagmgang, Jul 29 '22

CLA assistant check
All committers have signed the CLA.

CLAassistant, Jul 29 '22

Thanks for your contribution and we appreciate it a lot. We will review it soon.

Ezra-Yu, Jul 29 '22

Codecov Report

Merging #939 (69bf870) into master (71ef7ba) will increase coverage by 0.87%. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #939      +/-   ##
==========================================
+ Coverage   84.57%   85.44%   +0.87%     
==========================================
  Files         134      135       +1     
  Lines        8772     8810      +38     
  Branches     1516     1521       +5     
==========================================
+ Hits         7419     7528     +109     
+ Misses       1117     1058      -59     
+ Partials      236      224      -12     
Flag        Coverage Δ
unittests   85.39% <100.00%> (+0.81%) ↑

Flags with carried forward coverage won't be shown.

Impacted Files                        Coverage Δ
mmcls/models/necks/__init__.py        100.00% <100.00%> (ø)
mmcls/models/necks/map.py             100.00% <100.00%> (ø)
mmcls/core/visualization/image.py     87.17% <0.00%> (+0.97%) ↑
mmcls/datasets/builder.py             83.78% <0.00%> (+5.40%) ↑
mmcls/models/backbones/convmixer.py   98.36% <0.00%> (+9.83%) ↑
mmcls/datasets/custom.py              100.00% <0.00%> (+11.47%) ↑
mmcls/models/utils/helpers.py         100.00% <0.00%> (+20.83%) ↑
mmcls/utils/setup_env.py              95.45% <0.00%> (+22.72%) ↑
mmcls/utils/device.py                 100.00% <0.00%> (+27.27%) ↑
mmcls/models/backbones/convnext.py    94.78% <0.00%> (+34.78%) ↑


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 71ef7ba...69bf870.

codecov[bot], Jul 29 '22

Could you provide a config file that uses this neck?

mzr1996, Aug 02 '22

Quoting @mzr1996: "Could you provide a config file that uses this neck?"

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # pass all four stages to the neck
        style='pytorch',
    ),
    neck=dict(
        type='MultiheadAttentionPooling',
        in_channels=(256, 512, 1024, 2048),  # ResNet-50 stage widths
        num_heads=(8, 8, 8, 8),  # one head count per input stage
        out_channels=2048,
    ),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            num_classes=1000,
            reduction='mean',
            loss_weight=1.0),
        topk=(1, 5),
        cal_acc=False))
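Before training with the first config, it may help to sanity-check the per-stage numbers. The sketch below assumes the standard ResNet-50 stage strides (4, 8, 16, 32) and a 224×224 input, neither of which is stated in this thread, and assumes the neck splits each stage's channels evenly across its heads, as `torch.nn.MultiheadAttention` requires. The helper name `stage_summary` is hypothetical.

```python
def stage_summary(in_channels, num_heads, strides, img_size=224):
    """Return (channels, head_dim, num_tokens) for each stage fed to the neck."""
    rows = []
    for c, h, s in zip(in_channels, num_heads, strides):
        # Multi-head attention splits the embedding evenly across heads
        assert c % h == 0, f'{c} channels are not divisible by {h} heads'
        side = img_size // s  # spatial side length of this stage's map
        rows.append((c, c // h, side * side))
    return rows


# in_channels / num_heads from the first config; strides and size are assumed
rows = stage_summary(in_channels=(256, 512, 1024, 2048),
                     num_heads=(8, 8, 8, 8),
                     strides=(4, 8, 16, 32))
for c, head_dim, tokens in rows:
    print(f'C={c:4d}  head_dim={head_dim:3d}  tokens={tokens}')
```

Assuming a single pooling query per stage, the attention cost grows linearly with the number of tokens rather than quadratically as in full self-attention over all tokens.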

Or, feeding only the last ResNet stage to the neck:

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3,),
        style='pytorch',
    ),
    neck=dict(
        type='MultiheadAttentionPooling',
        in_channels=[2048],
        num_heads=[8],
        out_channels=2048,
    ),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            num_classes=1000,
            reduction='mean',
            loss_weight=1.0),
        topk=(1, 5),
        cal_acc=False))

chagmgang, Aug 02 '22

Hi @chagmgang! We are grateful for your efforts in helping improve the mmpretrain open-source project during your personal time.

Welcome to join the OpenMMLab Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA If you have a WeChat account, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + Github ID" as a remark when adding friends :)

Thank you again for your contribution ❤

OpenMMLab-Assistant-004, Apr 13 '23