mmpretrain

Potential bug in the Vision Transformer classification head

Open • HarborYuan opened this issue 2 years ago • 1 comment

Motivation

In Vision Transformers, the Linear layer of the classification head is usually initialized with a truncated normal (trunc normal) init, which is defined in the main model. If this is not carefully checked, the current implementation may instead initialize the head to zero. The original JAX implementation of ViT does use such a zero initialization, but many other Vision Transformers implemented in PyTorch use a trunc normal init instead, and this may not have been carefully considered in the config files in this repo (e.g. DeiT).
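To make the difference between the two schemes concrete, here is a minimal PyTorch sketch of a linear head under each one. The make_head helper, the embedding size (768), the class count (1000) and std=0.02 are illustrative assumptions rather than values taken from this repo; std=0.02 is the value commonly used by PyTorch ViT/DeiT code.

```python
import torch
import torch.nn as nn


def make_head(in_features: int, num_classes: int, zero_init: bool) -> nn.Linear:
    """Hypothetical helper: build a linear cls head with one of two init schemes."""
    head = nn.Linear(in_features, num_classes)
    if zero_init:
        # Zero init of the head weights, as in the original JAX ViT.
        nn.init.zeros_(head.weight)
    else:
        # Truncated normal init (std=0.02), as in many PyTorch ViT/DeiT codebases.
        nn.init.trunc_normal_(head.weight, std=0.02)
    nn.init.zeros_(head.bias)
    return head


torch.manual_seed(0)
features = torch.randn(4, 768)  # a batch of 4 hypothetical [CLS] embeddings

zero_head = make_head(768, 1000, zero_init=True)
tn_head = make_head(768, 1000, zero_init=False)

with torch.no_grad():
    zero_logits = zero_head(features)
    tn_logits = tn_head(features)

# A zero-initialized head gives every class the same (zero) logit,
# so the initial softmax is exactly uniform; trunc normal logits vary.
print(zero_logits.abs().max().item())  # 0.0
print(tn_logits.std().item() > 0)      # True
```

Neither choice is wrong in itself; the point is that a zero head is a deliberate design decision, so it should not be applied silently to models whose reference PyTorch implementations expect trunc normal init.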

Modification

Change the default init_cfg of the Vision Transformer cls head to None.
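The ordering problem behind this change can be sketched as follows, assuming (as the Motivation implies) that the head's own init runs after the model-level trunc normal init and therefore overrides it. TinyViTClassifier, its head_init argument and the "zeros" value are hypothetical stand-ins for the model and the head's init_cfg; they are not mmpretrain's actual classes or init machinery.

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyViTClassifier(nn.Module):
    """Hypothetical stand-in for a ViT: one backbone layer plus a linear cls head."""

    def __init__(self, embed_dim: int, num_classes: int,
                 head_init: Optional[str] = "zeros"):
        super().__init__()
        self.backbone = nn.Linear(embed_dim, embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        # Hypothetical analogue of the head's init_cfg: "zeros" or None.
        self.head_init = head_init

    def init_weights(self) -> None:
        # Step 1: model-level init, trunc normal on every Linear layer.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                nn.init.zeros_(m.bias)
        # Step 2: the head's own init runs afterwards and wins if it is set.
        if self.head_init == "zeros":
            nn.init.zeros_(self.head.weight)


torch.manual_seed(0)
old_default = TinyViTClassifier(768, 1000, head_init="zeros")
new_default = TinyViTClassifier(768, 1000, head_init=None)
old_default.init_weights()
new_default.init_weights()

print(bool(old_default.head.weight.eq(0).all()))  # True: head silently zeroed
print(bool(new_default.head.weight.eq(0).all()))  # False: trunc normal kept
```

With the head-level default set to None, the model-level init decides how the head starts, and a config that really wants the JAX-style zero head can still request it explicitly.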

BC-breaking (Optional)

Use cases (Optional)

Checklist

Before PR:

  • [ ] Pre-commit or other linting tools are used to fix the potential lint issues.
  • [ ] Bug fixes are fully covered by unit tests, the case that causes the bug should be added in the unit tests.
  • [ ] The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  • [ ] The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • [ ] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • [ ] CLA has been signed and all committers have signed the CLA in this PR.

HarborYuan • May 09 '22 02:05

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.75%. Comparing base (6beac50) to head (acb43e5). Report is 81 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #824      +/-   ##
==========================================
+ Coverage   86.68%   86.75%   +0.07%     
==========================================
  Files         128      130       +2     
  Lines        8255     8571     +316     
  Branches     1422     1478      +56     
==========================================
+ Hits         7156     7436     +280     
- Misses        885      911      +26     
- Partials      214      224      +10     
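The percentages in the diff above can be reproduced from its counts. In both columns Lines = Hits + Misses + Partials, and assuming Codecov's default of rounding coverage down to two decimal places, the sketch below recovers the reported 86.68%, 86.75% and +0.07%. The helper names are illustrative only.

```python
def coverage_bp(hits: int, misses: int, partials: int) -> int:
    """Coverage in basis points (hundredths of a percent), rounded down."""
    total = hits + misses + partials  # every tracked line falls in one bucket
    # Integer floor division avoids float surprises when truncating.
    return hits * 10_000 // total


def as_percent(bp: int) -> str:
    """Format non-negative basis points as a two-decimal percentage string."""
    return f"{bp // 100}.{bp % 100:02d}%"


base_bp = coverage_bp(7156, 885, 214)  # master (6beac50): 8255 lines
head_bp = coverage_bp(7436, 911, 224)  # PR head (acb43e5): 8571 lines
print(as_percent(base_bp), as_percent(head_bp), f"+{(head_bp - base_bp) / 100:.2f}")
# 86.68% 86.75% +0.07
```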
Flag        Coverage Δ
unittests   86.69% <ø> (+0.07%) ⬆

Flags with carried forward coverage won't be shown.

codecov[bot] • May 09 '22 02:05