
[FEATURE] Script to convert weights from JAX to PyTorch

Open yazdanbakhsh opened this issue 2 years ago • 6 comments

Is your feature request related to a problem? Please describe. I am trying to create multiple checkpoints of ViT at different iterations. Is there a systematic way to perform such a conversion?

Describe the solution you'd like I would like to be able to convert a JAX ViT model to a PyTorch model, similar to this one (https://huggingface.co/google/vit-base-patch16-224).

Describe alternatives you've considered I have tried pre-training HF models on A100s, but so far have not been able to reach the same accuracy.

yazdanbakhsh avatar Dec 23 '22 09:12 yazdanbakhsh

@yazdanbakhsh loading JAX .npz checkpoints is integrated into the timm ViT models for the original Google JAX implementations (big_vision support is being merged today):

  • https://github.com/google-research/vision_transformer
  • https://github.com/google-research/big_vision

However, I don't have support for the Hugging Face ViT checkpoints. Usually people go from timm -> Transformers, not the other way around. I honestly wouldn't recommend pretraining from scratch with the Transformers model; I don't think it's been well tested outside of use with pretrained weights, and other people have reported the same. timm is better tested for training from random weights.
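For reference, a minimal sketch of pointing a timm ViT at an original JAX .npz checkpoint via `create_model` (the checkpoint path is a placeholder, and this assumes a timm version that includes the .npz loading path):

```python
import timm

# Build the matching timm ViT architecture and load an original JAX .npz
# checkpoint; timm dispatches .npz files to its numpy/JAX weight loader.
model = timm.create_model(
    'vit_base_patch16_224',
    checkpoint_path='path/to/ViT-B_16.npz',  # placeholder local checkpoint
    num_classes=1000,
)
model.eval()
```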

rwightman avatar Dec 23 '22 18:12 rwightman

@rwightman Thanks for the explanation. Which script do you suggest to start with for distributed training of ViT using timm?

yazdanbakhsh avatar Dec 23 '22 18:12 yazdanbakhsh

@yazdanbakhsh I just dumped some hparams here, but as usual they need adapting to your scenario / specific network: https://gist.github.com/rwightman/943c0fe59293b44024bbd2d5d23e6303
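As a rough illustration of how such hparams are usually fed to timm's train.py for multi-GPU training (the GPU count, dataset root, and config filename below are placeholders rather than values from the gist, and this assumes a timm checkout where train.py takes the dataset root as a positional argument):

```python
import subprocess

# Hypothetical multi-GPU launch of timm's train.py via torchrun, pointing
# it at a YAML hparam file like the ones in the gist above.
cmd = [
    "torchrun", "--nproc_per_node=8", "train.py",
    "/data/imagenet",                    # placeholder dataset root
    "--config", "vit_b16_hparams.yaml",  # placeholder hparam file
    "--model", "vit_base_patch16_224",
]
subprocess.run(cmd, check=True)
```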

rwightman avatar Dec 23 '22 19:12 rwightman

@rwightman Thanks again for providing the details. timm is much more robust for training vision models; with this training script I was finally able to reach 76% accuracy for ViT-B/16. I also made a change to include cutout as part of the RandAugment pipeline to be more consistent with the big_vision repo (a rough sketch of such an op is below).
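To be concrete, the kind of op I mean is roughly the following standalone PIL transform composed next to timm's RandAugment; this is an illustrative sketch only, not my exact patch:

```python
import random
from PIL import Image, ImageDraw

class Cutout:
    """Paste a gray square at a random spot, roughly mimicking the cutout
    op in big_vision's RandAugment (illustrative only, assumes RGB input)."""

    def __init__(self, size: int = 40, fill: int = 128):
        self.size = size
        self.fill = fill

    def __call__(self, img: Image.Image) -> Image.Image:
        img = img.copy()
        w, h = img.size
        x = random.randint(0, max(0, w - self.size))
        y = random.randint(0, max(0, h - self.size))
        ImageDraw.Draw(img).rectangle(
            (x, y, x + self.size, y + self.size), fill=(self.fill,) * 3)
        return img
```

Composed, e.g., alongside the RandAugment transform and before ToTensor in the training pipeline.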

I am wondering if you have any other suggestions for config changes to reach an accuracy comparable to the HF model?

yazdanbakhsh avatar Dec 26 '22 20:12 yazdanbakhsh

The obvious changes that I want to make are as follows (based on big_vision):

  1. learning rate: 0.001
  2. weight decay: 0.0001
  3. warmup: 15 epochs (which, with my 632 steps per epoch, comes to roughly 10K steps; see the quick check below)
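
As a quick check of the warmup figure in item 3:

```python
# Warmup length implied by the numbers above.
steps_per_epoch = 632
warmup_epochs = 15
print(warmup_epochs * steps_per_epoch)  # 9480 steps, i.e. roughly 10K
```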

yazdanbakhsh avatar Dec 26 '22 20:12 yazdanbakhsh

@rwightman do you have the YAML configs for ResNet, MobileNet, and other models? Would it be possible to share them? I have also tried to use the new JSD loss, but it gives me an error.

yazdanbakhsh avatar Dec 30 '22 03:12 yazdanbakhsh