
Make Decoding Functions Graph-compatible

Open · abheesht17 opened this issue · 1 comment

I made an attempt at this here: https://colab.research.google.com/drive/1PBMzeBd-HyFE0o4VXwk19-kqHIhOZM49?usp=sharing, and ran into an issue:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-52-cee4282e867d> in <module>()
      1 inputs = keras.Input(shape=(), dtype="string")
----> 2 translated = decode_sequences(inputs)
      3 translated = tf.strings.unicode_decode(input=translated, input_encoding="UTF-8")
      4 translated = tf.strings.unicode_encode(input=translated, output_encoding="UTF-8")
      5 translated = tf.strings.regex_replace(translated, "\[PAD\]", "")

3 frames
<ipython-input-48-ace559a1a08f> in decode_sequences(input_sentences)
    151         prompt,
    152         max_length=40,
--> 153         end_token_id=spa_tokenizer.token_to_id("[END]"),
    154     )
    155     generated_sentences = spa_tokenizer.detokenize(generated_tokens)

<ipython-input-48-ace559a1a08f> in greedy_search(token_probability_fn, prompt, max_length, end_token_id, pad_token_id)
    123     if end_token_id is not None:
    124         prompt = mask_tokens_after_end_token(
--> 125             prompt, max_length, end_token_id, pad_token_id
    126         )
    127 

<ipython-input-48-ace559a1a08f> in mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id)
     36     # Build a mask including end_token and replace tokens after end_token
     37     # with `pad_token_id`.
---> 38     valid_indices = tf.sequence_mask(lengths=end_indices + 1, maxlen=max_length)
     39     return tf.where(valid_indices, prompt, pad_token_id)
     40 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

<__array_function__ internals> in result_type(*args, **kwargs)

TypeError: Cannot interpret '<KerasTensor: shape=(None,) dtype=int64 (created by layer 'tf.where_11')>' as a data type

This is most probably a bug in Keras. The error originates here: https://github.com/tensorflow/tensorflow/blob/v2.9.1/tensorflow/python/ops/array_ops.py#L4510. The .dtype accesses there fail when they receive a symbolic KerasTensor instead of a regular tensor. I'll open an issue in the Keras repository!
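For what it's worth, the masking step itself traces fine under tf.function when it is given real tensors; the failure only shows up when prompt is a symbolic KerasTensor coming from keras.Input. A minimal sketch (the end-index computation is simplified relative to the notebook, the token ids are made up, and for brevity it assumes every row contains an end token):

import tensorflow as tf

@tf.function
def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
    # Position of the first end token in each sequence.
    end_indices = tf.argmax(
        tf.cast(tf.equal(prompt, end_token_id), tf.int32),
        axis=-1,
        output_type=tf.int32,
    )
    # Keep everything up to and including the end token, pad the rest.
    valid_indices = tf.sequence_mask(lengths=end_indices + 1, maxlen=max_length)
    return tf.where(valid_indices, prompt, pad_token_id)

prompt = tf.constant([[5, 7, 2, 9, 9]])  # token id 2 stands in for "[END]"
print(mask_tokens_after_end_token(prompt, max_length=5, end_token_id=2, pad_token_id=0))
# tf.Tensor([[5 7 2 0 0]], shape=(1, 5), dtype=int32)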

Also, I'd be glad to work on making these functions graph-compatible :)

abheesht17 avatar Jun 29 '22 10:06 abheesht17

Thanks for opening! This is definitely a good thing to work on; getting efficient function tracing support for sequence generation is important. There are actually two variants of this requirement:

  • We should be able to trace functions normally (e.g. in a keras model, or with @tf.function).
  • We should be able to trace functions with XLA (e.g. a keras model or @tf.function with jit_compile=True).

IIUC, jit compilation will come with a few more restrictions on the graph code. We can do this incrementally too: start with plain graph tracing and add jit_compile support later.
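To make the difference concrete, here is a rough, self-contained sketch of the two modes (made-up token_probability_fn and fixed MAX_LENGTH, not the keras-nlp implementation). The main constraint illustrated for XLA is that nothing inside the decoding loop changes shape: the prompt buffer is preallocated to MAX_LENGTH and updated in place.

import tensorflow as tf

VOCAB_SIZE = 10
MAX_LENGTH = 8

def token_probability_fn(prompt):
    # Stand-in for a real model: uniform probabilities for the next token.
    return tf.ones((tf.shape(prompt)[0], VOCAB_SIZE)) / VOCAB_SIZE

def greedy_decode(prompt):
    # prompt has a fixed shape of (batch, MAX_LENGTH); only shape-preserving
    # ops are used inside the loop.
    for i in tf.range(1, MAX_LENGTH):
        probs = token_probability_fn(prompt)
        next_token = tf.argmax(probs, axis=-1, output_type=prompt.dtype)
        # Write the new token into column i without growing any tensor.
        mask = tf.one_hot(i, MAX_LENGTH, dtype=prompt.dtype)
        prompt = prompt * (1 - mask) + next_token[:, tf.newaxis] * mask
    return prompt

# 1. Plain graph tracing.
generate = tf.function(greedy_decode)

# 2. XLA compilation, which layers stricter requirements (static shapes,
#    a narrower set of supported ops) on top of plain tracing.
generate_xla = tf.function(greedy_decode, jit_compile=True)

prompt = tf.pad(tf.fill((2, 1), 1), [[0, 0], [0, MAX_LENGTH - 1]])
print(generate(prompt))
print(generate_xla(prompt))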

mattdangerw avatar Jun 30 '22 20:06 mattdangerw