MLPMixer-jax2tf
This repository hosts code for converting the original MLP-Mixer models [1] (JAX) to TensorFlow. The converted models are hosted on TensorFlow Hub and can be found here: https://tfhub.dev/sayakpaul/collections/mlp-mixer/1.
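For readers unfamiliar with the architecture, here is a minimal NumPy sketch of a single Mixer block (a token-mixing MLP followed by a channel-mixing MLP, each with a residual connection). This is only illustrative: the dimensions roughly mirror the B/16 variant, but the weights are random placeholders, not the converted ones, and the layer norm is unparameterized.

```python
import numpy as np

def gelu(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, eps=1e-6):
    # Normalize over the last (channel) axis; no learned scale/shift here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def mlp(x, w1, w2):
    return gelu(x @ w1) @ w2

def mixer_block(x, token_w1, token_w2, channel_w1, channel_w2):
    # Token mixing: transpose so the MLP acts across patches (tokens).
    y = layer_norm(x)
    y = np.swapaxes(mlp(np.swapaxes(y, 1, 2), token_w1, token_w2), 1, 2)
    x = x + y
    # Channel mixing: MLP acts across channels as usual.
    y = mlp(layer_norm(x), channel_w1, channel_w2)
    return x + y

rng = np.random.default_rng(0)
tokens, channels, d_token, d_channel = 196, 768, 384, 3072  # B/16-like sizes
x = rng.normal(size=(1, tokens, channels))
out = mixer_block(
    x,
    0.02 * rng.normal(size=(tokens, d_token)),
    0.02 * rng.normal(size=(d_token, tokens)),
    0.02 * rng.normal(size=(channels, d_channel)),
    0.02 * rng.normal(size=(d_channel, channels)),
)
print(out.shape)  # shape is preserved: (1, 196, 768)
```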
Note that TensorFlow 2.6 or greater is required to use the converted models.
Several model variants are available:

- SAM [2] pre-trained (these models were pre-trained on ImageNet-1k):
  - B/16 (classification, feature-extractor)
  - B/32 (classification, feature-extractor)
- ImageNet-1k fine-tuned:
  - B/16 (classification, feature-extractor)
  - L/16 (classification, feature-extractor)
- ImageNet-21k pre-trained:
  - B/16 (classification, feature-extractor)
  - L/16 (classification, feature-extractor)
For more details on the training protocols, please follow [1, 3].
The original model classes and weights [4] were converted using the jax2tf tool [5]. For details on the conversion process, please refer to the conversion.ipynb notebook.
I independently validated two models on the ImageNet-1k validation set. The table below reports the top-1 accuracies along with their respective logs on tensorboard.dev.
| Model | Top-1 Accuracy | tb.dev link |
|---|---|---|
| B/16 fine-tuned on ImageNet-1k | 75.31% | Link |
| B/16 pre-trained on ImageNet-1k using SAM | 75.58% | Link |
Here is a tensorboard.dev run that logs fine-tuning results (using this model) for the Flowers dataset.
Other notebooks

- classification.ipynb: Shows how to load an MLP-Mixer model from TensorFlow Hub and run image classification.
- fine-tune.ipynb: Shows how to fine-tune an MLP-Mixer model from TensorFlow Hub on the tf_flowers dataset.
References
[1] MLP-Mixer: An all-MLP Architecture for Vision by Tolstikhin et al.
[2] Sharpness-Aware Minimization for Efficiently Improving Generalization by Foret et al.
[5] jax2tf tool
Acknowledgements
Thanks to the ML-GDE program for providing GCP Credit support that helped me execute the experiments for this project.