Fine tuning TensorFlow DeBERTa fails on TPU
System Info
Latest version of transformers, Colab TPU, tensorflow 2.
- Colab TPU
- transformers: 4.21.0
- tensorflow: 2.8.2 / 2.6.2
- Python 3.7
Who can help?
@LysandreJik, @Rocketknight1, @san
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
I am facing some issues while trying to fine-tune a TensorFlow DeBERTa model microsoft/deberta-v3-base on TPU.
I have created some Colab notebooks showing the errors; note that the second and third notebooks already include some measures to circumvent the earlier errors. A minimal reproduction sketch follows the list below.
- ValueError with partially known TensorShape with the latest take_along_axis change: FineTuning_TF_DeBERTa_TPU_1
- Output shape mismatch of branches with the custom dropout: FineTuning_TF_DeBERTa_TPU_2
- XLA compilation error because of dynamic/computed tensor shapes: FineTuning_TF_DeBERTa_TPU_3
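For readers without access to the notebooks, a minimal reproduction sketch of the setup above could look roughly like this (the dataset, labels, and hyperparameters are placeholders, not the notebook contents):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Connect to the Colab TPU and create a distribution strategy.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Placeholder data; the real notebooks use a proper dataset.
texts = ["an example sentence"] * 128
labels = [0] * 128
enc = tokenizer(texts, padding="max_length", truncation=True, max_length=128, return_tensors="tf")
dataset = tf.data.Dataset.from_tensor_slices((dict(enc), labels)).batch(16, drop_remainder=True)

with strategy.scope():
    model = TFAutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(2e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

model.fit(dataset, epochs=1)  # this is where the TPU errors described above show up
```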
I have seen similar issues when using microsoft/deberta-base.
I believe the following issues are related:
- TF2 DeBERTaV2 runs super slow on TPUs #18239
- Debertav2 debertav3 TPU : socket closed #18276. From this issue I used the fix on take_along_axis.
Thanks!
Expected behavior
Fine-tuning works on TPU, just as it does when using a GPU.
Hi @tmoroder 👋 Thank you for adding all that information to the issue <3
If I got it right, the second notebook replaces the take_along_axis function, and the third notebook also replaces the custom dropout. Still, there are XLA exceptions.
Before diving into debugging, two questions:
- Does it return the same error on a GPU?
- I see that you prepare a dataset with static batch size and that the input is padded. Do you think that there is any additional source of shape variability in the inputs? (I don't think so, but asking doesn't hurt :D )
Hi @gante.
If I got it right, the second notebook replaces the take_along_axis function, and the third notebook also replaces the custom dropout. Still, there are XLA exceptions.
Correct. I think the XLA exceptions occur during gradient computation because of these dynamic/computed tensor shapes.
The first of these seems to me to be triggered within the TFDebertaV2DisentangledSelfAttention.disentangled_att_bias method, e.g., at L735. I am not sure about other positions, like TFDebertaV2DisentangledSelfAttention.call at L704.
- Does it return the same error on a GPU?
It runs on GPU without errors if I use transformers==4.20.1, see FineTuning_TF_DeBERTa_GPU. With version 4.21.0 I get the same ValueError.
- I see that you prepare a dataset with static batch size and that the input is padded. Do you think that there is any additional source of shape variability in the inputs? (I don't think so, but asking doesn't hurt :D )
No further shape variability as far as I can judge.
Hi @tmoroder, can you try on GPU with jit_compile=True in both 4.20 and 4.21? I believe the code had issues with XLA before 4.21, and TPU code is always compiled to XLA.
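For reference, compiling the Keras training step with XLA looks roughly like this (model name and hyperparameters are placeholders, not values from the notebooks):

```python
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("microsoft/deberta-v3-base", num_labels=2)
model.compile(
    optimizer=tf.keras.optimizers.Adam(2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    jit_compile=True,  # force the train step through XLA, as the TPU runtime does
)
```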
Interesting. Since transformers==4.20.1, there are only two DeBERTa PRs:
- https://github.com/huggingface/transformers/pull/17940 (should have no impact at all here)
- https://github.com/huggingface/transformers/pull/18256 (what should have been a TPU-friendly take_along_axis)
As @Rocketknight1 said, that data would be interesting. If v4.21 works on GPU but not on TPU, we are up for an interesting challenge :D
Hi @tmoroder, can you try on GPU with jit_compile=True in both 4.20 and 4.21?
Using jit_compile=True while compiling the model gives an error for both 4.20.1 and 4.21, e.g., FineTuning_TF_DeBERTa_GPU_Tests for 4.21; with 4.20.1 it crashes in the last command.
As @Rocketknight1 said, that data would be interesting. If v4.21 works on GPU but not on TPU, we are up for an interesting challenge :D
Without jit_compile=True it also fails on GPU with 4.21; with 4.20.1 it works.
That makes sense - we made changes to the model to make it XLA-compatible in 4.21. XLA compatibility is necessary for TPU support, so the 4.20 model would never have run on TPU. However, we seem to have some other TPU-specific issues now - in my testing I was able to get DeBERTa to work with XLA on GPU in 4.21.
Weird!
During my TPU and GPU tests, I was using a custom training loop instead of Keras's .fit(), though I'm not sure whether that actually matters.
In my custom training code, I got DeBERTa to train in an ELECTRA-style setup, with XLA enabled via jit_compile=True, with none of the issues mentioned above.
I will be sharing my code ASAP once I finish the pretraining and validate the results. It is based on NVIDIA's BERT and ELECTRA TF2 training code: https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/
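For context, the jit_compile part of such a custom training loop typically looks something like the sketch below; the model, optimizer, and batch structure are placeholders, and this is not the NVIDIA-based code referred to above:

```python
import tensorflow as tf

@tf.function(jit_compile=True)  # XLA-compile the whole training step
def train_step(model, optimizer, batch, labels):
    with tf.GradientTape() as tape:
        logits = model(batch, training=True).logits
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
        )
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```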
@tmoroder I can confirm that I can run your example with jit_compile=True (i.e. XLA compilation) on model.compile(), using a GPU, if I apply the two changes you made in your third TPU notebook:
- replace take_along_axis by the tf.einsum version
- replace the custom dropout by the standard dropout
If XLA compilation works, then it should run on TPU. I noticed that in your notebook you were using TF 2.6, which may explain the XLA failure. Are you able to bump your TPU TF version (to as high as possible)?
Meanwhile I'm opening a PR to reflect those two changes :)
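For readers following along, the general idea behind an einsum-based take_along_axis is a one-hot contraction along the gather axis, roughly as sketched below (a conceptual sketch only, not necessarily the exact code in the PR):

```python
import tensorflow as tf

def take_along_axis(x, indices):
    # Gather x[..., indices] along the last axis without tf.gather:
    # expand the indices into a one-hot mask and contract it with x,
    # which lowers to XLA/TPU-friendly dense ops.
    one_hot = tf.one_hot(indices, depth=tf.shape(x)[-1], dtype=x.dtype)  # [..., k, d]
    return tf.einsum("...kd,...d->...k", one_hot, x)                     # [..., k]

# Sanity check:
# take_along_axis(tf.constant([[1., 2., 3.]]), tf.constant([[2, 0]]))  ->  [[3., 1.]]
```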
@gante
Thanks a lot for your effort. Maybe I am doing something wrong... but using the code from your pull request it now runs on GPU (with jit_compile=True as an additional argument during model compilation), while it still fails on TPU (without jit_compile=True as an argument). I am using TF 2.8.2 in both cases, which is the current default in the Colab environment. On TPU it again seems to fail on the tile operation.
- Working GPU version: FineTuning_TF_DeBERTa_Propsed_Fix_GPU
- Failing TPU version: FineTuning_TF_DeBERTa_Propsed_Fix_TPU
(linking issues -- the Tile issue is also present in the following unsolved issue: https://github.com/huggingface/transformers/issues/14058)
The cause is trivial (the multiples argument of tf.tile can't have dynamic shapes), but the fix may not be :D Will look into it.
@tmoroder the dynamic shape in question is the batch size. I may be able to get an alternative to tf.tile, but I highly doubt that it will make a difference -- it will be a dynamic shape no matter how I turn it around, as it is not set.
Alternatively, could you try setting the batch_size argument in the Input layers? It will make that shape static, which should be enough to squash the problem you're seeing :)
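Concretely, that suggestion amounts to something like the following (BATCH_SIZE and MAX_LEN are placeholder values, not the ones from the notebooks):

```python
import tensorflow as tf

BATCH_SIZE = 16  # placeholder; must match the dataset batches (use drop_remainder=True)
MAX_LEN = 128    # placeholder sequence length

# Fixing batch_size makes the batch dimension static, so ops such as tf.tile
# inside the model no longer see a dynamic shape under XLA/TPU.
input_ids = tf.keras.Input(shape=(MAX_LEN,), batch_size=BATCH_SIZE, dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.Input(shape=(MAX_LEN,), batch_size=BATCH_SIZE, dtype=tf.int32, name="attention_mask")
```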
@gante
Great, setting the batch_size works 🥳. I only had to make sure that it is divisible by strategy.num_replicas_in_sync, see FineTuning_TF_DeBERTa_Working_Fix_TPU. Thanks a lot, I will now test the procedure on my real use case at hand.
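For completeness, the divisibility constraint can be met by deriving the batch size from the replica count, roughly as in this sketch (placeholder values; resolver setup as in the usual Colab/Kaggle TPU boilerplate):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

PER_REPLICA_BATCH_SIZE = 16  # placeholder
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync  # divisible by construction
```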
Wooo nice! 🎊
I'm closing this issue since the problem seems to be solved for now. Feel free to reopen if you run into new related issues. Also, if you have the authorization to, please share TPU-related findings -- I'm sure they will be useful for other users!
@tmoroder Hey, can I ask about the training throughput/performance you got with the TPUs?
@WissamAntoun
Here is some output that I get during the model.fit call. The model is very close to the one in the Colab notebooks, but the run is carried out on a Kaggle TPU.
Some further specification:
- model max length: 512
- batch size: 128
- 12800 training samples (or 100 steps per epoch)
- about 7500 validation samples
- smoothed cross-entropy loss
- accuracy and cross-entropy metrics
When calling model.fit, the following epoch times are printed, depending on the base model backbone:
- deberta-v3-base: 540s (632s first epoch)
- bert-base-uncased: 29s (115s first epoch)
Hope it helps!
Oh great! I mean, not great in the sense that the model is super slow on TPUs, but great that model.fit and my custom training loop have the same issue. You are getting 128 sentences * 100 batches / 540 s ≈ 24 sents/s, and I'm getting ~ sents/s, but for an ELECTRA-style training.
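As a quick check of that arithmetic (batch size, steps per epoch, and epoch time taken from the post above):

```python
batch_size, steps_per_epoch, epoch_seconds = 128, 100, 540
throughput = batch_size * steps_per_epoch / epoch_seconds  # ≈ 23.7 sentences per second
```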
Thank you for providing the numbers, they really helped.