
Model doesn't converge

Open liberbey opened this issue 4 years ago • 23 comments

We are trying to apply this method to a medical dataset with about 70K images (224 resolution) across 5 classes. However, our training doesn't converge; we tried a range of learning rates (e.g. 3e-3, 3e-4), but nothing seems to work. Currently our model reaches 45% accuracy, while the average accuracy reported for this dataset is around 85-90% (we trained for 100 epochs). Is there anything else we should tune?

Also, here is our configuration:

batch_size = 64
epochs = 400
lr = 3e-4
gamma = 0.7
seed = 42

import torch
from linformer import Linformer
from vit_pytorch.efficient import ViT  # efficient-attention ViT wrapper from this repo

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

efficient_transformer = Linformer(
    dim=128,
    seq_len=49 + 1,  # 7x7 patches + 1 cls token
    depth=4,
    heads=8,
    k=64
)

# Visual Transformer

model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=5,
    transformer=efficient_transformer,  # nn.Transformer(d_model=128, nhead=8),
    channels=1,
).to(device)

Thank you very much!

liberbey avatar Dec 20 '20 21:12 liberbey

@liberbey Hey Ahmet! One of the pitfalls of transformers is using settings that make the dimension per head too small. The dimension per head should be at least 32, and ideally 64. It is calculated as dim // heads, so in your case the dimension of each head is 16. Try increasing the dimension to 256 and increasing the sequence length (decrease the patch size to 16). I would be very surprised if it does not work.
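For clarity, the head-dimension arithmetic for the original and suggested settings (plain arithmetic, nothing repo-specific):

# dimension per attention head = dim // heads
dim, heads = 128, 8
print(dim // heads)   # 16 -> too small

dim = 256
print(dim // heads)   # 32 -> acceptable; 64 per head is better still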

lucidrains avatar Dec 21 '20 03:12 lucidrains

@liberbey depth should be at least 6

lucidrains avatar Dec 21 '20 03:12 lucidrains

@lucidrains Hey Phil, thank you very much for your help! I'll try these parameters and post the results here.

liberbey avatar Dec 21 '20 16:12 liberbey

@lucidrains We have changed the parameters as:

efficient_transformer = Linformer(
    dim=256,
    seq_len=197,  # 14x14 patches + 1 cls token
    depth=6,
    heads=8,
    k=64
)

# Visual Transformer

model = ViT(
    dim=256,
    image_size=224,
    patch_size=16,
    num_classes=5,
    transformer=efficient_transformer,
    channels=1,
).to(device)

But our model still does not converge. Here are the results:

[accuracy and loss curves attached]

Do you have any other suggestions? Thanks again!

liberbey avatar Dec 22 '20 15:12 liberbey

@liberbey I think the only option is to get a bunch of unlabelled images (in the millions) and do self-supervised learning with BYOL before fine-tuning on your dataset. Transformers only work well in the high-data / high-compute regime.
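For anyone following along, a minimal sketch of what BYOL pretraining could look like, based on the byol-pytorch README; the ResNet backbone, image size, and the unlabelled-image sampler below are placeholders to adapt to your own data:

import torch
from byol_pytorch import BYOL
from torchvision import models

# placeholder backbone; swap in whatever encoder you plan to fine-tune later
backbone = models.resnet50(pretrained=False)

learner = BYOL(
    backbone,
    image_size=224,
    hidden_layer='avgpool'   # which layer's output to use as the representation
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    # stand-in for a real DataLoader over your unlabelled images
    return torch.randn(20, 3, 224, 224)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update the target encoder's EMA weights

# afterwards, fine-tune `backbone` on the labelled 5-class dataset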

lucidrains avatar Dec 23 '20 04:12 lucidrains

@lucidrains Thanks again! We will try to find a larger dataset. By the way, these are validation results, not test results, so we wondered if there could be another problem with our approach: we expected the test results to be bad due to not using a pretrained model, but not the validation results. Also, do you have any thoughts on the dramatic drop around the 80th epoch?

liberbey avatar Dec 23 '20 20:12 liberbey

Did you use a special learning rate scheduler? The loss curve on my own dataset also looks unusual, check here. It seems that ViT is hard to train.

SuX97 avatar Dec 24 '20 06:12 SuX97

@SuX97 We didn't use anything special for tuning the learning rate; however, I am not sure if this repo comes with a default scheduler. @lucidrains

liberbey avatar Dec 24 '20 13:12 liberbey

@SuX97 @liberbey well, there's been a new development, you two should try https://github.com/lucidrains/vit-pytorch#distillation
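For reference, the distillation setup from that README section looks roughly like this; the ResNet teacher, dimensions, and distillation hyperparameters below are illustrative, not a recipe for this dataset:

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

# a convnet teacher; in practice, any classifier already trained on your task
teacher = resnet50(pretrained=True)

student = DistillableViT(
    image_size=224,
    patch_size=16,
    num_classes=5,
    dim=256,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1
)

distiller = DistillWrapper(
    student=student,
    teacher=teacher,
    temperature=3,   # distillation temperature
    alpha=0.5,       # weight between hard-label loss and distillation loss
    hard=False       # soft distillation
)

img = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 5, (2,))

loss = distiller(img, labels)
loss.backward()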

lucidrains avatar Dec 24 '20 19:12 lucidrains

you'll both still need at least a million images... haha

lucidrains avatar Dec 24 '20 19:12 lucidrains

@liberbey hey you'll probably need a schedule with linear warmup for training any transformer, look here for more info
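A minimal sketch of a linear-warmup schedule in plain PyTorch; the warmup length and the constant LR afterwards are arbitrary choices, not something this repo ships, and `model` / `train_loader` are assumed to be the ones defined earlier:

import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

warmup_steps = 10_000  # placeholder; tune to your dataset size

def lr_lambda(step):
    # ramp the LR linearly from 0 up to the base LR, then hold it constant
    return min(1.0, (step + 1) / warmup_steps)

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for step, (images, labels) in enumerate(train_loader):
    loss = torch.nn.functional.cross_entropy(model(images.to(device)), labels.to(device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # step the schedule every batch, not every epoch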

umbertov avatar Dec 24 '20 22:12 umbertov

@umbertov Thanks for the suggestion. Do you know if it is supported in this repo? @lucidrains

liberbey avatar Dec 26 '20 08:12 liberbey

@lucidrains I just want to make sure: can I first do BYOL, and then try Distillation on top of it using this repo?

liberbey avatar Dec 26 '20 08:12 liberbey

@liberbey sure! i think those two techniques are complementary

lucidrains avatar Dec 26 '20 17:12 lucidrains

Hello All,

Are there any updates regarding ViT convergence? I am facing the same issue. And do the suggested papers help in tackling this issue?

Thanks in advance.

eslambakr avatar Jul 30 '21 23:07 eslambakr

Hi All,

Very good question. I have implemented 8 transformers, and 7 of the ViTs have a serious convergence problem. The dataset size is 97,000.

Regards, Khawar

khawar-islam avatar Jul 31 '21 05:07 khawar-islam

Hello @khawar512,

Could you please mention the surviving model :D that works fine for you? If you could mention the other 7 models too, that would be great.

Thanks in advance.

eslambakr avatar Jul 31 '21 12:07 eslambakr

Hi @eslambakr

The best model is Pyramid ViT; the others were Swin, CaiT, CrossViT, CeiT, ViT, NesT, PVT.

khawar-islam avatar Jul 31 '21 13:07 khawar-islam

Thank you very much @khawar512. But I don't know what you mean by Pyramid ViT; does it differ from PVT? If yes, can you please share a reference to the paper and an implementation, if one exists?

Besides, I will try "DeiT: Training data-efficient image transformers & distillation through attention" (https://arxiv.org/abs/2012.12877) and share updates here, as I believe it should overcome the convergence problem.

eslambakr avatar Jul 31 '21 13:07 eslambakr

@eslambakr

I am talking about the paper below: https://arxiv.org/abs/2102.12122

khawar-islam avatar Jul 31 '21 14:07 khawar-islam

Any update on this? I also have problems getting ViT to converge on a medical classification dataset (50K images) that converges fine with CNNs.

hubtub2 avatar Jan 03 '23 23:01 hubtub2

I can confirm I have the same issue, and I'm not even using this repo. I'm using ContinuousTransformerWrapper from https://github.com/lucidrains/x-transformers together with an extra cls_token bolted on for classification, so this could be a general artefact of using transformers for classification. This is surprising to me. The reason I want to use something like simple_vit_1d is that I'm classifying speech, which has a lot of temporal content that a CNN will struggle to capture.
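For concreteness, a rough sketch of that kind of setup, assuming x-transformers' ContinuousTransformerWrapper / Encoder API; the feature dimension, sequence length, and class count below are placeholders:

import torch
import torch.nn as nn
from x_transformers import ContinuousTransformerWrapper, Encoder

class SpeechClassifier(nn.Module):
    def __init__(self, dim_in=80, dim=256, num_classes=5, max_seq_len=1024):
        super().__init__()
        # learned cls token in the input feature space, prepended to the sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim_in))
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_in,
            dim_out=dim,
            max_seq_len=max_seq_len + 1,  # +1 for the cls token
            attn_layers=Encoder(dim=dim, depth=6, heads=8),
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):                       # x: (batch, seq_len, dim_in)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = self.encoder(x)                     # (batch, seq_len + 1, dim)
        return self.head(x[:, 0])               # classify from the cls position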

pfeatherstone avatar Mar 21 '24 08:03 pfeatherstone

same here. ViT is bad except on ImageNet. that's funny

sipie800 avatar Apr 02 '24 13:04 sipie800