
Consider supporting JAX for faster TPU training

Open minimaxir opened this issue 4 years ago • 0 comments

Per https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation , JAX trains roughly twice as fast on a TPU as the corresponding PyTorch models, so it may be worthwhile to add support for it.

However, this depends on when Hugging Face adds Trainer support, since manually setting up the training loops is not easy.
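
For reference, here is a rough sketch of what a hand-rolled Flax/JAX training step involves, loosely modeled on the linked language-modeling examples. The model name, learning rate, and batch layout are illustrative assumptions, and this omits the `pmap`/data-sharding plumbing that real TPU training would also need:

```python
# Hedged sketch of a manual Flax causal-LM training step (not aitextgen code).
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from transformers import FlaxGPT2LMHeadModel

model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optax.adamw(learning_rate=5e-5),  # illustrative hyperparameter
)

@jax.jit
def train_step(state, batch, dropout_rng):
    """One optimization step; batch holds input_ids, attention_mask, labels."""
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        logits = state.apply_fn(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            params=params,
            dropout_rng=dropout_rng,
            train=True,
        ).logits
        # Shift so each position predicts the next token.
        shift_logits = logits[..., :-1, :]
        shift_labels = batch["labels"][..., 1:]
        onehot = jax.nn.one_hot(shift_labels, shift_logits.shape[-1])
        return optax.softmax_cross_entropy(shift_logits, onehot).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss, new_dropout_rng

# Example usage with dummy data (shapes are illustrative):
rng = jax.random.PRNGKey(0)
batch = {
    "input_ids": jnp.zeros((8, 128), dtype=jnp.int32),
    "attention_mask": jnp.ones((8, 128), dtype=jnp.int32),
    "labels": jnp.zeros((8, 128), dtype=jnp.int32),
}
state, loss, rng = train_step(state, batch, rng)
```

Even this single-device version requires managing the optimizer state, RNG keys, and loss shifting by hand, which is the boilerplate a Trainer-style abstraction would hide.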

minimaxir · Jul 24, 2021