transformers
TF: check embeddings range
What does this PR do?
Adds the same check that was recently added to TFBart, which asserts that the inputs are within the embedding input range, to all models with token embeddings. As a reminder: TF doesn't enforce this check by default on `tf.gather`-dependent operations on GPU, instead returning a vector of 0.0 when an index is out of bounds.
After this change, all `test_embeddings_out_of_bounds_raise_exception` tests pass (36 failures in the previous scheduled CI).
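The silent-zeros behavior and the explicit assertion can be illustrated with a minimal NumPy sketch; `check_embeddings_range` is a hypothetical helper for illustration, not the actual function added in this PR (the real models use a TF-side assertion such as `tf.debugging.assert_less` before the embedding lookup):

```python
import numpy as np

def check_embeddings_range(input_ids, vocab_size):
    # Hypothetical helper mirroring the kind of check this PR adds.
    # On GPU, TF's tf.gather does not raise for out-of-range ids and
    # instead returns a vector of 0.0, so models assert explicitly
    # before the embedding lookup.
    max_id = int(np.max(input_ids))
    if max_id >= vocab_size:
        raise ValueError(
            f"input_ids must be < vocab_size ({vocab_size}); got max id {max_id}"
        )

# In range: no error.
check_embeddings_range(np.array([[1, 2, 3]]), vocab_size=10)
```

With this guard in place, an out-of-range id fails loudly instead of silently producing zero embeddings.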
To simplify the review, there are 3 models you should check. All others are copy/paste from these.
- Bert (Encoder)
- GPT2 (Decoder)
- Pegasus (Encoder-Decoder with `TFSharedEmbeddings` or `TFWrappedEmbeddings`. Encoder-Decoder models that only use the embeddings at the decoder, like Speech2Text, also follow the same code pattern)
The documentation is not available anymore as the PR was closed or merged.
(cc @ydshieh -- this fixes a large number of scheduled CI failures)
@gante Is this PR ready to merge? I guess so, but I would like to wait for your confirmation (or better, for you to merge).
@ydshieh It was ready -- merged now :D