
TFR-BERT: How to integrate other BERT-like models

Open SaltwaterDev opened this issue 3 years ago • 3 comments

Hi all,

I am trying to develop a ranking function for Chinese text, so I have adopted some models that have been further fine-tuned to replace the original BERT model checkpoint. However, when I passed those models as checkpoints to tfrbert_example.py and ran it, the following error was shown:

Traceback (most recent call last):
  File "/content/AI-backend/ranking/./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/examples/tfrbert_example.py", line 274, in <module>
    tf.compat.v1.app.run()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/content/AI-backend/ranking/./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/examples/tfrbert_example.py", line 270, in main
    train_and_eval()
  File "/content/AI-backend/ranking/./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/examples/tfrbert_example.py", line 266, in train_and_eval
    bert_ranking_pipeline.train_and_eval(local_training=FLAGS.local_training)
  File "/content/AI-backend/ranking/bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/pipeline.py", line 422, in train_and_eval
    tf.estimator.train_and_evaluate(self._estimator, train_spec, eval_spec)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/training.py", line 505, in train_and_evaluate
    return executor.run()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/training.py", line 646, in run
    return self.run_local()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/training.py", line 747, in run_local
    saving_listeners=saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 349, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1175, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1208, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1388, in _train_with_estimator_spec
    tf.compat.v1.train.warm_start(*self._warm_start_settings)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/warm_starting_util.py", line 476, in warm_start
    ckpt_to_initialize_from, grouped_variables.keys())
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/warm_starting_util.py", line 408, in _get_object_checkpoint_renames
    .format(missing_names))
ValueError: Attempting to warm-start from an object-based checkpoint, but found that the checkpoint did not contain values for all variables. The following variables were missing: {'transformer/layer_2/output/bias', 'transformer/layer_4/self_attention/query/bias', 'transformer/layer_10/intermediate/kernel', 'transformer/layer_10/self_attention/query/kernel', 'transformer/layer_4/self_attention/value/kernel', 'transformer/layer_9/self_attention/attention_output/bias', 'transformer/layer_4/output_layer_norm/gamma', 'transformer/layer_1/self_attention/value/bias', 'transformer/layer_2/output_layer_norm/beta', 'transformer/layer_3/self_attention/key/kernel', 'transformer/layer_9/output_layer_norm/beta', 'transformer/layer_10/self_attention/key/bias', 'transformer/layer_9/output/bias', 'transformer/layer_10/self_attention/value/bias', 'transformer/layer_3/self_attention/query/kernel', 'transformer/layer_0/self_attention/value/bias', 'transformer/layer_5/self_attention/attention_output/kernel', 'transformer/layer_2/self_attention/attention_output/kernel', 'transformer/layer_7/self_attention/key/kernel', 'transformer/layer_11/output/kernel', 'transformer/layer_8/self_attention/attention_output/bias', 'transformer/layer_11/intermediate/kernel', 'transformer/layer_10/self_attention/value/kernel', 'transformer/layer_6/intermediate/kernel', 'transformer/layer_5/self_attention/query/bias', 'transformer/layer_9/self_attention/key/bias', 'transformer/layer_0/output_layer_norm/beta', 'embeddings/layer_norm/gamma', 'transformer/layer_2/self_attention/value/bias', 'transformer/layer_5/output/bias', 'transformer/layer_7/self_attention/attention_output/bias', 'transformer/layer_8/output/kernel', 'transformer/layer_4/self_attention/query/kernel', 'transformer/layer_11/self_attention/query/bias', 'pooler_transform/bias', 'transformer/layer_1/self_attention/value/kernel', 'transformer/layer_4/intermediate/bias', 'transformer/layer_8/self_attention/query/bias', 'transformer/layer_2/intermediate/kernel', 'transformer/layer_4/intermediate/kernel', 'transformer/layer_10/output_layer_norm/gamma', 'transformer/layer_8/self_attention/value/kernel', 'transformer/layer_10/self_attention/query/bias', 'transformer/layer_4/self_attention_layer_norm/gamma', 'transformer/layer_5/output_layer_norm/beta', 'transformer/layer_2/self_attention/key/bias', 'transformer/layer_1/self_attention/attention_output/bias', 'transformer/layer_6/self_attention/key/kernel', 'transformer/layer_0/output/bias', 'transformer/layer_1/self_attention/key/bias', 'transformer/layer_9/self_attention_layer_norm/beta', 'transformer/layer_6/output/kernel', 'transformer/layer_4/output_layer_norm/beta', 'transformer/layer_2/output_layer_norm/gamma', 'transformer/layer_5/intermediate/kernel', 'transformer/layer_8/self_attention/attention_output/kernel', 'transformer/layer_3/output/bias', 'transformer/layer_11/self_attention_layer_norm/beta', 'transformer/layer_3/self_attention/attention_output/bias', 'transformer/layer_7/self_attention/value/kernel', 'transformer/layer_8/output/bias', 'transformer/layer_3/output_layer_norm/gamma', 'transformer/layer_4/self_attention/key/bias', 'transformer/layer_6/self_attention_layer_norm/gamma', 'transformer/layer_6/self_attention/query/bias', 'transformer/layer_2/self_attention/query/kernel', 'transformer/layer_6/output_layer_norm/gamma', 'transformer/layer_0/self_attention/query/kernel', 'transformer/layer_5/self_attention_layer_norm/gamma', 'transformer/layer_5/intermediate/bias', 
'transformer/layer_0/self_attention/attention_output/bias', 'transformer/layer_2/output/kernel', 'transformer/layer_5/self_attention/value/kernel', 'transformer/layer_7/output_layer_norm/gamma', 'transformer/layer_8/self_attention/key/kernel', 'transformer/layer_9/intermediate/bias', 'transformer/layer_11/output/bias', 'transformer/layer_8/output_layer_norm/beta', 'transformer/layer_0/self_attention_layer_norm/beta', 'transformer/layer_10/self_attention/key/kernel', 'transformer/layer_3/self_attention_layer_norm/gamma', 'transformer/layer_2/self_attention_layer_norm/gamma', 'transformer/layer_0/intermediate/kernel', 'transformer/layer_8/intermediate/bias', 'transformer/layer_1/output/kernel', 'transformer/layer_9/output/kernel', 'transformer/layer_0/self_attention/value/kernel', 'transformer/layer_1/output_layer_norm/gamma', 'transformer/layer_6/self_attention/attention_output/kernel', 'transformer/layer_0/self_attention/key/bias', 'transformer/layer_4/output/bias', 'type_embeddings/embeddings', 'transformer/layer_1/self_attention_layer_norm/gamma', 'transformer/layer_8/self_attention/key/bias', 'transformer/layer_9/self_attention/key/kernel', 'transformer/layer_1/self_attention/attention_output/kernel', 'embeddings/layer_norm/beta', 'transformer/layer_3/self_attention/attention_output/kernel', 'transformer/layer_7/self_attention_layer_norm/gamma', 'transformer/layer_7/self_attention/attention_output/kernel', 'transformer/layer_1/intermediate/bias', 'transformer/layer_10/self_attention/attention_output/kernel', 'transformer/layer_9/self_attention/query/kernel', 'transformer/layer_11/intermediate/bias', 'transformer/layer_5/self_attention/value/bias', 'transformer/layer_11/self_attention/value/bias', 'transformer/layer_9/self_attention/attention_output/kernel', 'transformer/layer_0/self_attention/attention_output/kernel', 'transformer/layer_4/self_attention/attention_output/bias', 'transformer/layer_4/self_attention/key/kernel', 'transformer/layer_3/intermediate/kernel', 'transformer/layer_3/output_layer_norm/beta', 'transformer/layer_3/output/kernel', 'transformer/layer_9/intermediate/kernel', 'transformer/layer_7/output/bias', 'pooler_transform/kernel', 'transformer/layer_3/self_attention/value/kernel', 'transformer/layer_0/self_attention/query/bias', 'transformer/layer_11/self_attention/key/kernel', 'transformer/layer_10/self_attention_layer_norm/gamma', 'transformer/layer_5/output_layer_norm/gamma', 'transformer/layer_10/self_attention_layer_norm/beta', 'transformer/layer_6/self_attention_layer_norm/beta', 'transformer/layer_9/output_layer_norm/gamma', 'transformer/layer_0/output/kernel', 'transformer/layer_10/intermediate/bias', 'transformer/layer_5/self_attention/attention_output/bias', 'transformer/layer_6/self_attention/value/bias', 'transformer/layer_7/output_layer_norm/beta', 'transformer/layer_5/self_attention/key/kernel', 'transformer/layer_2/self_attention_layer_norm/beta', 'transformer/layer_3/self_attention/query/bias', 'transformer/layer_6/output/bias', 'transformer/layer_4/self_attention_layer_norm/beta', 'transformer/layer_7/self_attention/value/bias', 'transformer/layer_5/self_attention/query/kernel', 'transformer/layer_3/self_attention/value/bias', 'transformer/layer_6/self_attention/value/kernel', 'position_embedding/embeddings', 'transformer/layer_0/intermediate/bias', 'transformer/layer_7/self_attention/key/bias', 'transformer/layer_8/self_attention/value/bias', 'transformer/layer_10/output_layer_norm/beta', 'transformer/layer_6/intermediate/bias', 
'transformer/layer_5/self_attention/key/bias', 'transformer/layer_11/self_attention/key/bias', 'transformer/layer_1/self_attention/key/kernel', 'transformer/layer_8/self_attention_layer_norm/gamma', 'transformer/layer_1/intermediate/kernel', 'transformer/layer_5/output/kernel', 'transformer/layer_7/self_attention/query/bias', 'transformer/layer_11/output_layer_norm/beta', 'transformer/layer_1/output/bias', 'transformer/layer_6/self_attention/query/kernel', 'transformer/layer_6/self_attention/attention_output/bias', 'transformer/layer_9/self_attention/value/kernel', 'transformer/layer_7/self_attention_layer_norm/beta', 'transformer/layer_11/self_attention/attention_output/bias', 'transformer/layer_3/self_attention/key/bias', 'transformer/layer_1/self_attention/query/kernel', 'transformer/layer_1/self_attention/query/bias', 'transformer/layer_2/self_attention/query/bias', 'transformer/layer_9/self_attention/query/bias', 'transformer/layer_10/output/bias', 'transformer/layer_11/self_attention/value/kernel', 'transformer/layer_2/self_attention/value/kernel', 'transformer/layer_7/intermediate/kernel', 'transformer/layer_11/self_attention/query/kernel', 'transformer/layer_0/self_attention/key/kernel', 'transformer/layer_11/self_attention/attention_output/kernel', 'transformer/layer_8/self_attention/query/kernel', 'transformer/layer_1/output_layer_norm/beta', 'transformer/layer_6/output_layer_norm/beta', 'transformer/layer_9/self_attention_layer_norm/gamma', 'transformer/layer_8/output_layer_norm/gamma', 'transformer/layer_2/self_attention/key/kernel', 'transformer/layer_7/self_attention/query/kernel', 'transformer/layer_3/intermediate/bias', 'transformer/layer_4/self_attention/attention_output/kernel', 'transformer/layer_4/output/kernel', 'transformer/layer_5/self_attention_layer_norm/beta', 'transformer/layer_3/self_attention_layer_norm/beta', 'transformer/layer_7/output/kernel', 'transformer/layer_0/self_attention_layer_norm/gamma', 'transformer/layer_8/self_attention_layer_norm/beta', 'transformer/layer_8/intermediate/kernel', 'transformer/layer_10/output/kernel', 'transformer/layer_6/self_attention/key/bias', 'transformer/layer_9/self_attention/value/bias', 'transformer/layer_11/output_layer_norm/gamma', 'transformer/layer_2/self_attention/attention_output/bias', 'transformer/layer_1/self_attention_layer_norm/beta', 'transformer/layer_0/output_layer_norm/gamma', 'transformer/layer_7/intermediate/bias', 'word_embeddings/embeddings', 'transformer/layer_11/self_attention_layer_norm/gamma', 'transformer/layer_10/self_attention/attention_output/bias', 'transformer/layer_4/self_attention/value/bias', 'transformer/layer_2/intermediate/bias'}

I think these are the "standard" variables collected by ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start) in tensorflow/python/training/warm_starting_util.py, and I don't know how to get past this. Meanwhile, I have read your paper Learning-to-Rank with BERT in TF-Ranking, and it seems that you integrated RoBERTa and ELECTRA in those experiments. I wonder how that was done. If I may be so bold, would you consider developing a more generic TFR-BERT framework compatible with other BERT-like models?
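
For reference, the variable names actually stored in a checkpoint can be listed and compared against the names in the error above. A minimal sketch, assuming a TF2 environment and a placeholder checkpoint path:

```python
import tensorflow as tf

# Placeholder path: the checkpoint prefix I saved from the Hugging Face model.
ckpt_path = "/path/to/my_electra_checkpoint/ckpt-1"

# Print every (variable name, shape) pair stored in the checkpoint.
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)
```

As far as I can tell, none of the names printed this way match the transformer/layer_*/... names listed in the error.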

SaltwaterDev avatar Jul 01 '21 12:07 SaltwaterDev

Hi @WesleyHung,

Is your customized BERT checkpoint a TF2 checkpoint? Did you also provide the JSON file (similar to the toy example here) as the bert_config_file argument?

As for a more generic TFR-BERT framework compatible with other BERT-like models, yes, it is on our TODO list ;)

HongleiZhuang avatar Jul 01 '21 20:07 HongleiZhuang

@HongleiZhuang Thank you for your reply! Yes, I used it as a TF2 checkpoint and provided the config JSON file. Actually, that customized BERT-like model does not ship with a TF checkpoint (for reference: https://huggingface.co/toastynews/electra-hongkongese-base-discriminator), so I loaded the TF model with the Hugging Face library and then saved it as a checkpoint file, roughly as in the sketch below.
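
Roughly, the conversion looked like this (the output path is a placeholder, and the exact code may have differed slightly):

```python
import tensorflow as tf
from transformers import TFElectraModel

# Load the TF version of the Hugging Face model linked above
# (from_pt=True may be needed if only PyTorch weights are published).
model = TFElectraModel.from_pretrained(
    "toastynews/electra-hongkongese-base-discriminator")

# Save the Keras model as a TF2 object-based checkpoint.
ckpt = tf.train.Checkpoint(model=model)
ckpt.save("/path/to/my_electra_checkpoint/ckpt")
```

I suspect the variable names produced this way do not match the transformer/layer_*/... names that the TF Model Garden BERT encoder expects, which would explain the warm-start error.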

SaltwaterDev avatar Jul 02 '21 04:07 SaltwaterDev

Hi @WesleyHung,

Can you inspect the checkpoint you generated and the checkpoint downloaded from the TF Model Garden and see whether they have the same variable names? If not, you may need to generate a mapping between them, use it to build your own warm-start settings, and pass them in like here (see the sketch below). Another option would be to convert the PyTorch model directly into a TensorFlow 2 checkpoint.
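
For illustration, a minimal sketch of the mapping approach. The checkpoint path and the renaming rule below are placeholders: the real correspondence has to be derived by comparing the output of tf.train.list_variables() on both checkpoints, and how the resulting settings are passed into the TFR-BERT pipeline follows the link above.

```python
import tensorflow as tf

my_ckpt = "/path/to/my_electra_checkpoint/ckpt-1"  # checkpoint with your own variable names

# Map {variable name expected by the TFR-BERT model: name in your checkpoint}.
# The replace() rule is purely illustrative; derive the actual rule by
# inspecting both checkpoints.
name_map = {}
for name, _ in tf.train.list_variables(my_ckpt):
    expected_name = name.replace("model/electra/encoder", "transformer")  # hypothetical
    name_map[expected_name] = name

warm_start_settings = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=my_ckpt,
    vars_to_warm_start=list(name_map.keys()),
    var_name_to_prev_var_name=name_map)

# warm_start_settings can then be passed to the estimator / ranking pipeline
# as shown in the linked example.
```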

Let me know if any of these options work for you.

HongleiZhuang avatar Jul 09 '21 22:07 HongleiZhuang