mmpretrain
[Feature] MultiheadAttentionPooling
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
This PR adds a multi-head attention pooling neck. The feature is inspired by Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks.
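For context, the core idea from Set Transformer is "pooling by multihead attention" (PMA): a small set of learnable seed vectors attends over the input set, producing a fixed-size, permutation-invariant summary. A minimal single-seed sketch using `torch.nn.MultiheadAttention` (this is an illustration of the technique, not the PR's actual code; the class name `AttentionPool` is made up for this sketch):

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Minimal PMA-style pooling: one learnable seed query attends over
    the flattened feature map, collapsing (N, C, H, W) to (N, C)."""

    def __init__(self, channels, num_heads=8):
        super().__init__()
        # Learnable query ("seed") that pools the token sequence, as in
        # the Set Transformer PMA block.
        self.seed = nn.Parameter(torch.zeros(1, 1, channels))
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        n, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)     # (N, H*W, C)
        seed = self.seed.expand(n, -1, -1)        # (N, 1, C)
        out, _ = self.attn(seed, tokens, tokens)  # (N, 1, C)
        return out.squeeze(1)                     # (N, C)


pool = AttentionPool(64)
feat = torch.rand(3, 64, 7, 7)
print(pool(feat).shape)  # torch.Size([3, 64])
```

Because every spatial location is treated as a set element, the pooled output is invariant to permutations of the H*W tokens.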
Modification
The multi-head attention pooling neck is implemented following the code style of the existing hr_fuse neck.
BC-breaking (Optional)
Use cases (Optional)
The use cases of this feature can be seen in the test code; a usage example is also presented below.
```python
import torch

from mmcls.models.necks import HRFuseScales

in_channels = (18, 32, 64, 128)
neck = HRFuseScales(in_channels=in_channels, out_channels=1024)

# Build one feature map per scale, halving the spatial size each time.
feat_size = 56
inputs = []
for in_channel in in_channels:
    input_tensor = torch.rand(3, in_channel, feat_size, feat_size)
    inputs.append(input_tensor)
    feat_size = feat_size // 2

outs = neck(tuple(inputs))
assert isinstance(outs, tuple)
assert len(outs) == 1
assert outs[0].shape == (3, 1024, 7, 7)
```
Checklist
Before PR:
- [x] Pre-commit or other linting tools are used to fix the potential lint issues.
- [x] Bug fixes are fully covered by unit tests, the case that causes the bug should be added in the unit tests.
- [x] The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
- [x] The documentation has been modified accordingly, like docstring or example tutorials.
After PR:
- [x] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
- [x] CLA has been signed and all committers have signed the CLA in this PR.
Thanks for your contribution and we appreciate it a lot. We will review it soon.
Codecov Report
Merging #939 (69bf870) into master (71ef7ba) will increase coverage by 0.87%. The diff coverage is 100.00%.
```diff
@@            Coverage Diff             @@
##           master     #939      +/-   ##
==========================================
+ Coverage   84.57%   85.44%   +0.87%
==========================================
  Files         134      135       +1
  Lines        8772     8810      +38
  Branches     1516     1521       +5
==========================================
+ Hits         7419     7528     +109
+ Misses       1117     1058      -59
+ Partials      236      224      -12
```
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 85.39% <100.00%> (+0.81%) | :arrow_up: |
Flags with carried forward coverage won't be shown. Click here to find out more.
| Impacted Files | Coverage Δ | |
|---|---|---|
| mmcls/models/necks/__init__.py | 100.00% <100.00%> (ø) | |
| mmcls/models/necks/map.py | 100.00% <100.00%> (ø) | |
| mmcls/core/visualization/image.py | 87.17% <0.00%> (+0.97%) | :arrow_up: |
| mmcls/datasets/builder.py | 83.78% <0.00%> (+5.40%) | :arrow_up: |
| mmcls/models/backbones/convmixer.py | 98.36% <0.00%> (+9.83%) | :arrow_up: |
| mmcls/datasets/custom.py | 100.00% <0.00%> (+11.47%) | :arrow_up: |
| mmcls/models/utils/helpers.py | 100.00% <0.00%> (+20.83%) | :arrow_up: |
| mmcls/utils/setup_env.py | 95.45% <0.00%> (+22.72%) | :arrow_up: |
| mmcls/utils/device.py | 100.00% <0.00%> (+27.27%) | :arrow_up: |
| mmcls/models/backbones/convnext.py | 94.78% <0.00%> (+34.78%) | :arrow_up: |
Continue to review full report at Codecov.
Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update 71ef7ba...69bf870. Read the comment docs.
Could you provide any config file which uses this neck?
```python
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        style='pytorch',
    ),
    neck=dict(
        type='MultiheadAttentionPooling',
        in_channels=(256, 512, 1024, 2048),
        num_heads=(8, 8, 8, 8),
        out_channels=2048,
    ),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            num_classes=1000,
            reduction='mean',
            loss_weight=1.0),
        topk=(1, 5),
        cal_acc=False))
```
or
```python
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3,),
        style='pytorch',
    ),
    neck=dict(
        type='MultiheadAttentionPooling',
        in_channels=[2048],
        num_heads=[8],
        out_channels=2048,
    ),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            num_classes=1000,
            reduction='mean',
            loss_weight=1.0),
        topk=(1, 5),
        cal_acc=False))
```
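To make the config fields above concrete, here is a hedged sketch of a neck with this interface: per-scale attention pooling followed by a concatenation and projection. The class name matches the config, but the internals (per-scale seed queries, the final linear projection) are assumptions for illustration, not the PR's actual implementation:

```python
import torch
import torch.nn as nn


class MultiheadAttentionPooling(nn.Module):
    """Hypothetical sketch matching the config interface: one attention-
    pooling branch per input scale, concatenated and projected to
    out_channels. The real neck in this PR may differ."""

    def __init__(self, in_channels, num_heads, out_channels):
        super().__init__()
        # One learnable seed query per input scale.
        self.seeds = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, 1, c)) for c in in_channels])
        self.attns = nn.ModuleList(
            [nn.MultiheadAttention(c, h, batch_first=True)
             for c, h in zip(in_channels, num_heads)])
        self.proj = nn.Linear(sum(in_channels), out_channels)

    def forward(self, inputs):
        pooled = []
        for x, seed, attn in zip(inputs, self.seeds, self.attns):
            tokens = x.flatten(2).transpose(1, 2)  # (N, H*W, C)
            out, _ = attn(seed.expand(x.size(0), -1, -1), tokens, tokens)
            pooled.append(out.squeeze(1))          # (N, C)
        return self.proj(torch.cat(pooled, dim=1))  # (N, out_channels)


neck = MultiheadAttentionPooling((256, 512), (8, 8), 1024)
feats = (torch.rand(2, 256, 14, 14), torch.rand(2, 512, 7, 7))
print(neck(feats).shape)  # torch.Size([2, 1024])
```

With `in_channels=[2048]` and `num_heads=[8]`, as in the second config, this degenerates to a single attention-pooling branch over the last backbone stage.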
Hi @chagmgang! We are grateful for your efforts in helping improve the mmpretrain open-source project during your personal time.
Welcome to join the OpenMMLab Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA If you have a WeChat account, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends :)
Thank you again for your contribution ❤