
[FEATURE] Huge discrepancy between HuggingFace and timm in terms of the initialization of ViT

Open Phuoc-Hoan-Le opened this issue 3 years ago • 7 comments

I see a huge discrepancy between HuggingFace and timm in the initialization of ViT. timm's implementation uses trunc_normal_, whereas HuggingFace uses "module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)". I noticed this caused a huge drop in performance when training ViT models on ImageNet with the HuggingFace implementation. I'm not sure whether it's just the initialization or something more. Could one of you check and try to make the HuggingFace implementation consistent with the timm version? Thanks!

Phuoc-Hoan-Le avatar Sep 28 '22 18:09 Phuoc-Hoan-Le

@CharlesLeeeee that trunc normal at std=.02 isn't actually THAT different from normal; at least I doubt it's different enough to have a significant impact on training. The other timm init modes, however, like the jax and moco styles, are pretty different and can alter from-scratch training dynamics depending on hparams, dataset size, etc.
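The "not THAT different" point can be checked directly (a minimal sketch; the tensor shape is illustrative): PyTorch's trunc_normal_ defaults to absolute cutoffs a=-2/b=2, which at std=0.02 sit roughly 100 standard deviations from the mean, so the truncation almost never fires.

```python
import torch
import torch.nn as nn

# Illustrative tensor, sized like a ViT-Base linear weight (shape is arbitrary).
w_normal = torch.empty(768, 768)
w_trunc = torch.empty(768, 768)

nn.init.normal_(w_normal, mean=0.0, std=0.02)                      # HF-style init
nn.init.trunc_normal_(w_trunc, mean=0.0, std=0.02, a=-2.0, b=2.0)  # timm-style init

# With std=0.02 the ±2 cutoff is ~100 sigma out, so truncation is a no-op
# in practice and the two empirical distributions match closely.
print(w_normal.std().item(), w_trunc.std().item())
```

With a larger std (say 1.0) the two inits diverge visibly, which is where trunc_normal actually earns its keep.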

Are you sure training differences aren't due to other things?

While I'm a part of Hugging Face, my focus is still on timm, timm is considered more experimental, I add / tweak / improve when I see opportunity. HF Transformers is more stable and backwards compatibility is a high priority, I wouldn't be able to go in and change the init.

rwightman avatar Sep 28 '22 19:09 rwightman

My training loop and hyperparameters are based on timm's, except for the model, I believe

Phuoc-Hoan-Le avatar Sep 28 '22 22:09 Phuoc-Hoan-Le

@CharlesLeeeee k, but say for hparams: timm's create_model passes some regularization params through to timm models that wouldn't work for transformers, so if your args have drop_path (stochastic depth) or drop (dropout) enabled, they wouldn't be used for transformers

Also, the timm optimizer factory grabs some layer names to skip weight decay on special layers (like pos embed / token) via timm-specific functions; you'd have to do that manually for the transformers ViT for equivalence...
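A hedged sketch of that grouping logic (the function and argument names below are my own for illustration, not timm's actual API): skip weight decay on 1-dim params (norms, biases) and on named special tensors like pos_embed / cls_token, then hand both groups to the optimizer.

```python
import torch
import torch.nn as nn

def make_param_groups(model, weight_decay=0.05,
                      no_decay_names=('pos_embed', 'cls_token')):
    # Rough equivalent of what timm's optimizer factory does:
    # no weight decay on 1-dim params or named special tensors.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or any(nd in name for nd in no_decay_names):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

# Usage with a toy module standing in for the ViT:
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 8))
        self.fc = nn.Linear(8, 8)

opt = torch.optim.AdamW(make_param_groups(Toy()), lr=1e-3)
```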

rwightman avatar Sep 28 '22 23:09 rwightman

I see. I did add stochastic depth to the HuggingFace model and made sure the drop argument affects the dropout in the HuggingFace model. However, I was unaware that weight decay is skipped on some special layers like the pos embed and cls_token. Thanks!

Is there anything else you could list that a HuggingFace user should be aware of when using the timm training pipeline/file with a HuggingFace model instead of a timm model? Thanks!

Phuoc-Hoan-Le avatar Sep 29 '22 01:09 Phuoc-Hoan-Le

@rwightman currently the training loss is going NaN after a few epochs

Phuoc-Hoan-Le avatar Sep 29 '22 04:09 Phuoc-Hoan-Le

@CharlesLeeeee both are NaN? vit from scratch usually requires grad clipping + adamw. FYI, if you can't resolve the differences for training, you could always train with timm and remap the checkpoints to HF... they are compatible, just named differently, with possibly a few extra/missing 1-dim axes in the embed tensors that can be squeezed/unsqueezed...
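A minimal training-step sketch of that combo (the model, lr, and max_norm below are placeholder values, not recommendations): AdamW plus global grad-norm clipping, the thing timm's train script turns on via its clip-grad option.

```python
import torch

model = torch.nn.Linear(16, 10)  # stand-in for the ViT
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

x = torch.randn(4, 16)
y = torch.randint(0, 10, (4,))
loss = torch.nn.functional.cross_entropy(model(x), y)

opt.zero_grad()
loss.backward()
# Clip the global grad norm before stepping; this is typically what keeps
# from-scratch ViT training from diverging to NaN.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```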

rwightman avatar Sep 29 '22 17:09 rwightman

@rwightman I am talking about the HuggingFace model and I am trying to train DeiT as in https://arxiv.org/abs/2012.12877 without the distillation process.

So I got the HuggingFace ViT model to train properly. The conditions I was training under were:

- torch==1.7.0, torchvision==0.8.1, timm==0.3.2 (for image processing, optimizer, and lr_scheduler purposes)
- Removing the head mask and anything relating to pruning heads in the HuggingFace ViT model class file
- Using the exact same initialization as timm==0.3.2 rather than the one from HuggingFace
- Removing the HuggingFace ViT model's inheritance from PreTrainedModel
- Adding DropPath to the class

Note: I am not sure which of these (if any) directly fixed the problem, but the run crashed or I got NaN when using torch==1.11.0 and torchvision==0.12.0.

Phuoc-Hoan-Le avatar Oct 03 '22 15:10 Phuoc-Hoan-Le