TF: XLA-trainable DeBERTa v2
What does this PR do?
As discussed in https://github.com/huggingface/transformers/issues/18476 and https://github.com/huggingface/transformers/issues/18239, there are two problems when training DeBERTa v2 with TensorFlow:
- `TFDebertaV2StableDropout` doesn't work at training time (actually, its logic is only triggered at training time, so it doesn't work at all :D)
- TF complains about unknown shapes in `take_along_axis` (forward and backward passes, when the batch dim is `None`)
This PR fixes both problems above :)
Problem 1 got a straightforward fix. The gradient propagation code didn't have the right gradient shapes -- this PR simplifies and fixes it by moving all functions inside the special dropout class (compare to the original PT implementation here -- also notice how much more elegant TF's code is ;)).
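For intuition, here is a minimal sketch of that pattern (names and the mask construction are simplified, not the exact `transformers` implementation): the mask and scale are built inside the layer's own method, so the custom-gradient closure always sees tensors whose shapes match the forward pass.

```python
import tensorflow as tf


class StableDropoutSketch(tf.keras.layers.Layer):
    """Illustrative dropout layer whose forward and backward logic live in one method."""

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    @tf.custom_gradient
    def xdropout(self, inputs):
        # Build the mask from the runtime shape, so it also works when dims are None at trace time.
        mask = tf.random.uniform(tf.shape(inputs)) < self.drop_prob  # True where we drop
        scale = tf.cast(1.0 / (1.0 - self.drop_prob), inputs.dtype)
        outputs = tf.where(mask, tf.zeros_like(inputs), inputs) * scale

        def grad(upstream):
            # Same mask and scale as the forward pass, so the gradient shape always matches `inputs`.
            return tf.where(mask, tf.zeros_like(upstream), upstream) * scale

        return outputs, grad

    def call(self, inputs, training=False):
        if training and self.drop_prob > 0:
            return self.xdropout(inputs)
        return inputs
```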
Problem 2 is trickier. The exception gets fixed with the addition of a `shape_list`, but the resulting code is super slow on TPU. This PR adds an if/else pair of branches, one that is efficient on TPU, the other on GPU :)
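As an illustration (the function name and the 3-D shapes below are assumptions for this sketch, not the exact helper in the modeling file), the two-branch gather can look like this:

```python
import tensorflow as tf


def take_along_last_axis(x: tf.Tensor, indices: tf.Tensor) -> tf.Tensor:
    """Gather along the last axis of 3-D tensors (a TF port of np.take_along_axis)."""
    if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
        # TPU branch: dynamic gathers compile poorly under XLA on TPU, so express the
        # gather as a one-hot multiplication instead.
        one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
        gathered = tf.einsum("bspd,bsd->bsp", one_hot_indices, x)
    else:
        # GPU/CPU branch: flatten the leading dims, gather, and reshape back using
        # dynamic shapes (tf.shape / shape_list) so tracing with batch dim None works.
        flat_x = tf.reshape(x, (-1, x.shape[-1]))
        flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
        gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
        gathered = tf.reshape(gathered, tf.shape(indices))
    return gathered
```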
⚠️ These exceptions were not caught because DeBERTa v2 and v3 rely on special config options -- e.g. https://huggingface.co/microsoft/deberta-v3-base/blob/main/config.json#L14
How can we ensure we properly test these configurations?
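One option, sketched below with made-up tiny sizes (this is not an existing test, and the exact option set is an assumption), is a test that builds a small model with the non-default relative-attention options the v3 checkpoints ship with and runs it in training mode under `tf.function` with an unknown batch dim:

```python
import tensorflow as tf
from transformers import DebertaV2Config, TFDebertaV2Model

# Tiny config with the non-default options real DeBERTa v2/v3 checkpoints use; the
# important part is relative attention with position buckets, which routes through
# take_along_axis and the stable dropout at training time.
config = DebertaV2Config(
    vocab_size=128,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    relative_attention=True,
    position_buckets=256,
    pos_att_type=["p2c", "c2p"],
    hidden_dropout_prob=0.1,
)
model = TFDebertaV2Model(config)

# Exercise the training code path (dropout active) with batch dim None at trace time.
@tf.function(input_signature=[tf.TensorSpec((None, 16), tf.int32)])
def forward(input_ids):
    return model(input_ids, training=True).last_hidden_state

print(forward(tf.ones((2, 16), dtype=tf.int32)).shape)
```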
@gante I think it's better to replace

```python
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, shape_list(indices))
```

with

```python
gathered = tf.gather(x, indices, batch_dims=2)
```

which gives the same numerical results and the same performance according to my tests: https://github.com/huggingface/transformers/issues/18239#issuecomment-1193126061
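For reference, a quick way to check the equivalence on toy data (made-up sizes, purely illustrative):

```python
import tensorflow as tf

batch, seq, pos, dim = 2, 8, 8, 16
x = tf.random.normal((batch, seq, dim))
indices = tf.random.uniform((batch, seq, pos), maxval=dim, dtype=tf.int32)

# Four-line version: flatten, gather with batch_dims=1, reshape back.
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered_a = tf.reshape(tf.gather(flat_x, flat_indices, batch_dims=1), indices.shape)

# One-line version: treat the first two dims as batch dims.
gathered_b = tf.gather(x, indices, batch_dims=2)

tf.debugging.assert_equal(gathered_a, gathered_b)  # both forms pick the same elements
```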
@WissamAntoun thank you for pointing it out, I completely missed it in the original thread! 🙏 Will make the change
EDIT: this change also makes it ~10% faster 👍