
TF: check embeddings range

gante opened this pull request · 3 comments

What does this PR do?

Adds the same check that was recently added to TFBart, which asserts that the inputs are within the embedding input range, to all models with token embeddings. As a reminder: by default, TF does not enforce this check for tf.gather-based operations on GPU, silently returning a vector of 0.0s when an index is out of bounds.
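For context, here is a minimal sketch of the kind of guard being added, assuming a toy vocabulary size and a standalone `tf.gather` lookup (not the exact code in this PR): the assertion fails explicitly on any device instead of relying on `tf.gather`'s device-dependent out-of-bounds behavior.

```python
import tensorflow as tf

vocab_size, hidden_size = 10, 4
embedding_weights = tf.random.normal((vocab_size, hidden_size))

def checked_embedding_lookup(input_ids):
    # Sketch of the check: assert before the lookup so out-of-bounds indices fail
    # loudly on GPU too, instead of silently producing rows of zeros.
    tf.debugging.assert_less(
        input_ids,
        tf.cast(vocab_size, dtype=input_ids.dtype),
        message="input_ids must be smaller than the embedding layer's input dimension",
    )
    return tf.gather(embedding_weights, input_ids)

print(checked_embedding_lookup(tf.constant([[1, 2, 3]])).shape)  # (1, 3, 4)

try:
    checked_embedding_lookup(tf.constant([[1, 2, vocab_size]]))  # index out of bounds
except tf.errors.InvalidArgumentError as exc:
    print("caught out-of-bounds input:", type(exc).__name__)
```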

After this change, all test_embeddings_out_of_bounds_raise_exception tests pass (36 failures in the previous scheduled CI).

To simplify the review, there are only 3 models you need to check; all the others are copy/paste from these (see the sketch after this list for where the check sits in each pattern).

  1. Bert (Encoder)
  2. GPT2 (Decoder)
  3. Pegasus (Encoder-Decoder using TFSharedEmbeddings or TFWrappedEmbeddings; Encoder-Decoder models that only use the embeddings in the decoder, such as Speech2Text, follow the same code pattern)
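As a rough illustration of where the check sits in these patterns, here is a hypothetical toy encoder layer (`ToyTFEncoder` is an invented name, not a class from the repository): the assertion runs only when the lookup goes through `input_ids`, and is skipped when `inputs_embeds` is passed directly.

```python
import tensorflow as tf

class ToyTFEncoder(tf.keras.layers.Layer):
    """Toy layer sketching the placement of the check, under the assumption that
    it mirrors the encoder/decoder pattern described in this PR."""

    def __init__(self, vocab_size=10, hidden_size=4, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_tokens = tf.keras.layers.Embedding(vocab_size, hidden_size)

    def call(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            # Check input_ids against the vocabulary size right before the lookup,
            # so GPU runs fail as loudly as CPU runs.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message="input_ids must be smaller than the embedding layer's input dimension",
            )
            inputs_embeds = self.embed_tokens(input_ids)
        return inputs_embeds

encoder = ToyTFEncoder()
print(encoder(tf.constant([[0, 1, 2]])).shape)  # (1, 3, 4)
```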

gante avatar Sep 19 '22 11:09 gante

The documentation is not available anymore as the PR was closed or merged.

(cc @ydshieh -- this fixes a large number of scheduled CI failures)

gante avatar Sep 20 '22 11:09 gante

@gante Is this PR ready to merge? I guess so, but would like to wait for your confirmation (or better, for you to merge).

ydshieh avatar Sep 21 '22 12:09 ydshieh

> @gante Is this PR ready to merge? I guess so, but would like to wait for your confirmation (or better, for you to merge).

@ydshieh It was ready -- merged now :D

gante avatar Sep 22 '22 12:09 gante