aitextgen
Consider supporting JAX for faster TPU training
Per https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation, JAX is about twice as fast to train on a TPU as the corresponding PyTorch models, so it may be worthwhile to add support for it.
However, this depends on when Hugging Face adds Trainer support for Flax, since manually setting up the training loops is not easy.
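To illustrate what "manually setting up the loops" entails, here is a minimal sketch of a hand-written JAX training step. This is a hypothetical example with a toy linear model standing in for a transformer, not aitextgen's or Hugging Face's actual API; the function names (`loss_fn`, `train_step`) are assumptions for illustration.

```python
# Hypothetical sketch of a manual JAX training loop; a toy linear model
# stands in for the transformer. Not aitextgen's actual implementation.
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean squared error for a linear model y = x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # JIT-compile the step; on TPU this is where the speedup comes from.
def train_step(params, x, y, lr=0.1):
    # Compute loss and gradients in one pass, then apply plain SGD.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Tiny synthetic dataset and parameter initialization.
params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([2.0, 3.0])

# The loop itself: state threading, logging, checkpointing, and data
# sharding across TPU cores all have to be handled by hand here.
for _ in range(100):
    params, loss = train_step(params, x, y)
```

In a real language-model setting, this step would also need gradient accumulation, learning-rate scheduling, and `pmap`-style data parallelism across TPU cores, which is exactly the boilerplate a Trainer abstraction would hide.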