Make Decoding Functions Graph-compatible
I made an attempt to do the above here: https://colab.research.google.com/drive/1PBMzeBd-HyFE0o4VXwk19-kqHIhOZM49?usp=sharing, but ran into an issue:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-52-cee4282e867d> in <module>()
      1 inputs = keras.Input(shape=(), dtype="string")
----> 2 translated = decode_sequences(inputs)
      3 translated = tf.strings.unicode_decode(input=translated, input_encoding="UTF-8")
      4 translated = tf.strings.unicode_encode(input=translated, output_encoding="UTF-8")
      5 translated = tf.strings.regex_replace(translated, "\[PAD\]", "")

3 frames
<ipython-input-48-ace559a1a08f> in decode_sequences(input_sentences)
    151         prompt,
    152         max_length=40,
--> 153         end_token_id=spa_tokenizer.token_to_id("[END]"),
    154     )
    155     generated_sentences = spa_tokenizer.detokenize(generated_tokens)

<ipython-input-48-ace559a1a08f> in greedy_search(token_probability_fn, prompt, max_length, end_token_id, pad_token_id)
    123     if end_token_id is not None:
    124         prompt = mask_tokens_after_end_token(
--> 125             prompt, max_length, end_token_id, pad_token_id
    126         )
    127

<ipython-input-48-ace559a1a08f> in mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id)
     36     # Build a mask including end_token and replace tokens after end_token
     37     # with `pad_token_id`.
---> 38     valid_indices = tf.sequence_mask(lengths=end_indices + 1, maxlen=max_length)
     39     return tf.where(valid_indices, prompt, pad_token_id)
     40

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

<__array_function__ internals> in result_type(*args, **kwargs)

TypeError: Cannot interpret '<KerasTensor: shape=(None,) dtype=int64 (created by layer 'tf.where_11')>' as a data type
```
This is most probably a bug in Keras. The error originates here: https://github.com/tensorflow/tensorflow/blob/v2.9.1/tensorflow/python/ops/array_ops.py#L4510; all the `.dtype` calls there throw an error when given a `KerasTensor`. I'll open an issue in the Keras repository!
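For reference, the failure can be reproduced without any of the decoding code. The linked `array_ops.py` line ends up calling NumPy's `result_type` on its inputs, and NumPy cannot interpret a symbolic `KerasTensor` as a dtype or array. A minimal sketch (assuming TF 2.9-era behavior; the exact exception message may vary by version):

```python
import numpy as np
from tensorflow import keras

# Symbolic tensor produced by the Keras functional API; it has no concrete value.
x = keras.Input(shape=(None,), dtype="int64")

# np.result_type is what the linked array_ops.py code ends up calling.
# It cannot treat a KerasTensor as a NumPy dtype or array, so it raises
# (a TypeError "Cannot interpret ... as a data type" in TF 2.9).
try:
    np.result_type(x)
    raised = False
except Exception:
    raised = True
print(raised)
```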
Also, I'd be glad to work on making the functions graph-compatible :)
Thanks for opening! This is definitely a good thing to work on; getting efficient function tracing support for sequence generation is important. There are actually two variants of this requirement:
- We should be able to trace functions normally (e.g. in a keras model, or with @tf.function).
- We should be able to trace functions with XLA (e.g. a keras model or @tf.function with jit_compile=True).
IIUC, jit compilation will come with a few more restrictions on the graph code. We can do this incrementally too: start with plain function tracing, and add jit_compile support later.
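To make the two variants concrete, here is a hedged sketch of the masking step from the traceback written with graph-friendly ops, traced both ways. The function name and signature are mine, not the actual KerasNLP implementation, and it assumes every row of `prompt` contains the end token:

```python
import tensorflow as tf

def mask_after_end(prompt, end_token_id, pad_token_id):
    """Replace every token after the first end token with pad_token_id.

    Sketch only: assumes prompt is a dense int tensor of shape
    (batch, max_length) and that each row contains end_token_id.
    """
    max_length = tf.shape(prompt)[1]
    # Index of the first occurrence of end_token_id in each row.
    end_positions = tf.argmax(
        tf.cast(tf.equal(prompt, end_token_id), tf.int32), axis=1
    )
    # Keep tokens up to and including the end token; pad the rest.
    valid = tf.sequence_mask(end_positions + 1, maxlen=max_length)
    return tf.where(valid, prompt, pad_token_id)

# Variant 1: plain function tracing.
traced = tf.function(mask_after_end)
# Variant 2: XLA compilation, which adds extra restrictions
# (e.g. shapes must be statically inferable inside the compiled graph).
xla_traced = tf.function(mask_after_end, jit_compile=True)

prompt = tf.constant([[5, 7, 2, 9, 9]])  # 2 plays the role of "[END]" here
print(traced(prompt, end_token_id=2, pad_token_id=0).numpy())  # -> [[5 7 2 0 0]]
```

Working with concrete eager or traced tensors like this avoids the `KerasTensor` path entirely, which is why the same `tf.sequence_mask` + `tf.where` combination succeeds here but fails when fed a symbolic `keras.Input`.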