
TF: XLA-trainable DeBERTa v2

gante opened this pull request 1 year ago • 2 comments

What does this PR do?

As discussed in https://github.com/huggingface/transformers/issues/18476 and https://github.com/huggingface/transformers/issues/18239, there are two problems while training DeBERTa v2 with TensorFlow:

  1. TFDebertaV2StableDropout doesn't work at training time (actually, its logic is only triggered at training time, so it doesn't work at all :D)
  2. TF complains about unknown shapes in take_along_axis (forward and backward passes, when the batch dim is None)

This PR fixes both problems above :)

Problem 1. has a straightforward fix. The gradient propagation code didn't produce gradients with the right shapes -- this PR simplifies and fixes it by moving all the dropout functions inside the special dropout class (compare to the original PT implementation here -- also notice how much more elegant TF's code is ;)).
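For illustration, here is a minimal sketch of that pattern (the class name and mask sampling are illustrative, not the exact PR code): the mask is sampled once in the forward pass and captured by the custom gradient, so the forward and backward shapes always agree.

import tensorflow as tf

class StableDropout(tf.keras.layers.Layer):
    # Sketch of a self-contained "stable" dropout; the real layer is
    # TFDebertaV2StableDropout.
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, inputs, training=False):
        if not training or self.drop_prob == 0.0:
            return inputs

        @tf.custom_gradient
        def _dropout(x):
            # Sample the mask once; grad() below captures it, so the
            # backward pass reuses the exact same mask and scale.
            mask = tf.random.uniform(tf.shape(x)) < self.drop_prob
            scale = 1.0 / (1.0 - self.drop_prob)

            def grad(upstream):
                return tf.where(mask, 0.0, upstream) * scale

            return tf.where(mask, 0.0, x) * scale, grad

        return _dropout(inputs)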

Problem 2. is trickier. The exception gets fixed by adding a shape_list call, but the resulting code is very slow on TPU. This PR therefore adds an if/else pair of branches: one that is efficient on TPU, the other on GPU :)
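In rough terms, the shape of that fix looks like this (a hedged sketch -- the actual device check and integration live in the PR; on_tpu is a stand-in):

import tensorflow as tf

def take_along_axis(x, indices, on_tpu):
    # Gather along the last axis: x is [B, S, D], indices is [B, S, P].
    if on_tpu:
        # TPUs prefer a one-hot matmul over gathers with dynamic shapes:
        # [B, S, P, D] . [B, S, D] -> [B, S, P]
        one_hot = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
        return tf.einsum("bspd,bsd->bsp", one_hot, x)
    # On GPU, a plain batched gather is the fast path.
    return tf.gather(x, indices, batch_dims=2)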


⚠️ These exceptions were not caught because DeBERTa v2 and v3 rely on special config options -- e.g. https://huggingface.co/microsoft/deberta-v3-base/blob/main/config.json#L14

How can we ensure we properly test these configurations?

gante • Aug 09 '22 15:08

The documentation is not available anymore as the PR was closed or merged.

@gante I think it's better to replace

flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, shape_list(indices))

with

gathered = tf.gather(x, indices, batch_dims=2)

which gives the same numerical results and the same performance according to my tests

https://github.com/huggingface/transformers/issues/18239#issuecomment-1193126061
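For the record, a quick self-contained check of the equivalence (toy shapes; static shapes used in place of shape_list):

import tensorflow as tf

x = tf.random.normal((2, 3, 5))  # [batch, seq, hidden]
indices = tf.random.uniform((2, 3, 4), maxval=5, dtype=tf.int32)

# original: flatten, gather with batch_dims=1, reshape back
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
a = tf.reshape(tf.gather(flat_x, flat_indices, batch_dims=1), indices.shape)

# suggested: a single gather with batch_dims=2
b = tf.gather(x, indices, batch_dims=2)

tf.debugging.assert_equal(a, b)  # identical results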

WissamAntoun • Aug 10 '22 00:08

@WissamAntoun thank you for pointing it out, I completely missed it in the original thread! 🙏 Will make the change

EDIT: this change also makes it ~10% faster 👍

gante • Aug 10 '22 09:08