keras-io

Would it be possible to make lstm_seq2seq support mixed precision?

Open · joshuayao opened this issue 9 months ago · 1 comment

Issue Type

Bug

Source

binary

Keras Version

2.16.0

Custom Code

No

OS Platform and Distribution

No response

Python version

3.11

GPU model and memory

No response

Current Behavior?

`lstm_seq2seq.py` works well with the default fp32 data type when using legacy Keras (enabled via `import os; os.environ["TF_USE_LEGACY_KERAS"] = "1"`).

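With mixed precision enabled (`import tensorflow as tf; tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")`), training completed successfully, but inference failed with: `Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'.` As the log below shows, during inference the decoder layer `lstm_1` receives bfloat16 inputs but float32 initial states.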

Standalone code to reproduce the issue or tutorial link

To reproduce, add the following snippet at the beginning of the code example at https://github.com/keras-team/keras-io/blob/master/examples/nlp/lstm_seq2seq.py:

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

Relevant log output

Traceback (most recent call last):
  File "examples/nlp/lstm_seq2seq.py", line 332, in <module>
    decoded_sentence = decode_sequence(input_seq)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "examples/nlp/lstm_seq2seq.py", line 297, in decode_sequence
    output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_filehtao8ahn.py", line 15, in tf__predict_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileg32txcku.py", line 45, in tf__step_function
    outputs = ag__.converted_call(ag__.ld(model).distribute_strategy.run, (ag__.ld(run_step),), dict(args=(ag__.ld(data),)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileg32txcku.py", line 18, in run_step
    outputs = ag__.converted_call(ag__.ld(model).predict_step, (ag__.ld(data),), None, fscope_1)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file8lg3jru0.py", line 32, in tf__predict_step
    retval_ = ag__.converted_call(ag__.ld(self), (ag__.ld(x),), dict(training=False), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
    ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
    raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
    retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filetingiv6p.py", line 67, in tf____call__
    retval_ = ag__.converted_call(ag__.converted_call(ag__.ld(super), (), None, fscope).__call__, tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
    ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
    raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
    retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filedgxei06x.py", line 242, in tf____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(_in_functional_construction_mode), (ag__.ld(self), ag__.ld(inputs), ag__.ld(args), ag__.ld(kwargs), ag__.ld(input_list)), None, fscope), if_body_11, else_body_11, get_state_11, set_state_11, ('do_return', "kwargs['mask']", 'retval_', 'args', 'input_list', 'inputs', 'kwargs'), 3)
  File "/tmp/__autograph_generated_filedgxei06x.py", line 187, in else_body_11
    outputs = ag__.converted_call(ag__.ld(call_fn), (ag__.ld(inputs),) + tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 51, in error_handler
    ag__.if_stmt(ag__.converted_call(ag__.ld(hasattr), (ag__.ld(e), '_keras_call_info_injected'), None, fscope_1), if_body, else_body, get_state, set_state, (), 0)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 47, in if_body
    raise ag__.ld(e)
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
    retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqmtratyp.py", line 29, in tf__call
    retval_ = ag__.converted_call(ag__.ld(self)._run_internal_graph, (ag__.ld(inputs),), dict(training=ag__.ld(training), mask=ag__.ld(mask)), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 174, in tf___run_internal_graph
    ag__.for_stmt(ag__.ld(depth_keys), None, loop_body_4, get_state_9, set_state_9, (), {'iterate_names': 'depth'})
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 166, in loop_body_4
    ag__.for_stmt(ag__.ld(nodes), None, loop_body_3, get_state_8, set_state_8, (), {'iterate_names': 'node'})
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 165, in loop_body_3
    ag__.if_stmt(ag__.not_(continue__3), if_body_4, else_body_4, get_state_7, set_state_7, ('continue__3',), 0)
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 160, in if_body_4
    ag__.if_stmt(ag__.not_(continue__3), if_body_3, else_body_3, get_state_6, set_state_6, (), 0)
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 145, in if_body_3
    outputs = ag__.converted_call(ag__.ld(node).layer, tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file9p2pb1je.py", line 184, in tf____call__
    ag__.if_stmt(ag__.and_(lambda: ag__.ld(initial_state) is None, lambda: ag__.ld(constants) is None), if_body_7, else_body_7, get_state_8, set_state_8, ('do_return', "kwargs['constants']", "kwargs['initial_state']", 'retval_', 'self._num_constants', 'self.constants_spec', 'self.input_spec', 'self.state_spec'), 8)
  File "/tmp/__autograph_generated_file9p2pb1je.py", line 175, in else_body_7
    ag__.if_stmt(ag__.ld(is_keras_tensor), if_body_6, else_body_6, get_state_7, set_state_7, ('do_return', "kwargs['constants']", "kwargs['initial_state']", 'retval_', 'self.input_spec'), 5)
  File "/tmp/__autograph_generated_file9p2pb1je.py", line 168, in else_body_6
    retval_ = ag__.converted_call(ag__.converted_call(ag__.ld(super), (), None, fscope).__call__, (ag__.ld(inputs),), dict(**ag__.ld(kwargs)), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
    ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
    raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
    retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filedgxei06x.py", line 242, in tf____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(_in_functional_construction_mode), (ag__.ld(self), ag__.ld(inputs), ag__.ld(args), ag__.ld(kwargs), ag__.ld(input_list)), None, fscope), if_body_11, else_body_11, get_state_11, set_state_11, ('do_return', "kwargs['mask']", 'retval_', 'args', 'input_list', 'inputs', 'kwargs'), 3)
  File "/tmp/__autograph_generated_filedgxei06x.py", line 187, in else_body_11
    outputs = ag__.converted_call(ag__.ld(call_fn), (ag__.ld(inputs),) + tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 162, in error_handler
    raise ag__.converted_call(ag__.ld(new_e).with_traceback, (ag__.ld(e).__traceback__,), None, fscope_1) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
    retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 169, in tf__call
    ag__.if_stmt(ag__.not_(ag__.ld(self)._could_use_gpu_kernel), if_body_5, else_body_5, get_state_5, set_state_5, ('kwargs', 'last_output', 'outputs', 'runtime', 'states', 'inputs'), 5)
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 153, in else_body_5
    ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body_4, else_body_4, get_state_4, set_state_4, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 142, in else_body_4
    ag__.if_stmt(ag__.converted_call(ag__.ld(tf).executing_eagerly, (), None, fscope), if_body_3, else_body_3, get_state_3, set_state_3, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 134, in else_body_3
    last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(lstm_with_backend_selection), (), dict(**ag__.ld(normal_lstm_kwargs)), fscope)
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filecam0bgs0.py", line 118, in tf__lstm_with_backend_selection
    ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body, else_body, get_state, set_state, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
  File "/tmp/__autograph_generated_filecam0bgs0.py", line 107, in else_body
    last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(defun_standard_lstm), (), dict(**ag__.ld(params)), fscope)
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: in user code:

    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2436, in predict_function  *
        return step_function(self, iterator)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2409, in run_step  *
        outputs = model.predict_step(data)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2377, in predict_step  *
        return self(x, training=False)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler  *
        del filtered_tb
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 588, in __call__  *
        return super().__call__(*args, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler  *
        del filtered_tb
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__  *
        outputs = call_fn(inputs, *args, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/functional.py", line 514, in call  *
        return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/functional.py", line 671, in _run_internal_graph  *
        outputs = node.layer(*args, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/base_rnn.py", line 627, in __call__  *
        return super().__call__(inputs, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 560, in error_handler  *
        filtered_tb = _process_traceback_frames(e.__traceback__)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__  *
        outputs = call_fn(inputs, *args, **kwargs)
    File "/tmp/__autograph_generated_file_fd34cvd.py", line 162, in error_handler  **
        raise ag__.converted_call(ag__.ld(new_e).with_traceback, (ag__.ld(e).__traceback__,), None, fscope_1) from None
    File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
        retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 169, in tf__call  **
        ag__.if_stmt(ag__.not_(ag__.ld(self)._could_use_gpu_kernel), if_body_5, else_body_5, get_state_5, set_state_5, ('kwargs', 'last_output', 'outputs', 'runtime', 'states', 'inputs'), 5)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 153, in else_body_5
        ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body_4, else_body_4, get_state_4, set_state_4, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 142, in else_body_4
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).executing_eagerly, (), None, fscope), if_body_3, else_body_3, get_state_3, set_state_3, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 134, in else_body_3
        last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(lstm_with_backend_selection), (), dict(**ag__.ld(normal_lstm_kwargs)), fscope)
    File "/tmp/__autograph_generated_filecam0bgs0.py", line 118, in tf__lstm_with_backend_selection  **
        ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body, else_body, get_state, set_state, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
    File "/tmp/__autograph_generated_filecam0bgs0.py", line 107, in else_body
        last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(defun_standard_lstm), (), dict(**ag__.ld(params)), fscope)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/lstm.py", line 983, in standard_lstm
        last_output, outputs, new_states = backend.rnn(
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/backend.py", line 4985, in rnn
        output_time_zero, _ = step_function(
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/lstm.py", line 970, in step
        z += backend.dot(h_tm1, recurrent_kernel)

    TypeError: Exception encountered when calling layer 'lstm_1' (type LSTM).
    
    Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'.
    
    Call arguments received by layer 'lstm_1' (type LSTM):
      • inputs=tf.Tensor(shape=(None, 1, 91), dtype=bfloat16)
      • mask=None
      • training=False
      • initial_state=['tf.Tensor(shape=(None, 256), dtype=float32)', 'tf.Tensor(shape=(None, 256), dtype=float32)']

joshuayao avatar May 14 '24 01:05 joshuayao