TFR-BERT: How to integrate other BERT-like models
Hi all,
I am trying to develop a ranking function for Chinese text, so I have adopted some models that have been further fine-tuned, to replace the original BERT model checkpoint. However, when I passed those models in as checkpoints to tfrbert_example.py and ran it, I got the following error:
Traceback (most recent call last):
  File "/content/AI-backend/ranking/./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/examples/tfrbert_example.py", line 274, in <module>
    tf.compat.v1.app.run()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/content/AI-backend/ranking/./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/examples/tfrbert_example.py", line 270, in main
    train_and_eval()
  File "/content/AI-backend/ranking/./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/examples/tfrbert_example.py", line 266, in train_and_eval
    bert_ranking_pipeline.train_and_eval(local_training=FLAGS.local_training)
  File "/content/AI-backend/ranking/bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary.runfiles/org_tensorflow_ranking/tensorflow_ranking/extension/pipeline.py", line 422, in train_and_eval
    tf.estimator.train_and_evaluate(self._estimator, train_spec, eval_spec)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/training.py", line 505, in train_and_evaluate
    return executor.run()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/training.py", line 646, in run
    return self.run_local()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/training.py", line 747, in run_local
    saving_listeners=saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 349, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1175, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1208, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1388, in _train_with_estimator_spec
    tf.compat.v1.train.warm_start(*self._warm_start_settings)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/warm_starting_util.py", line 476, in warm_start
    ckpt_to_initialize_from, grouped_variables.keys())
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/warm_starting_util.py", line 408, in _get_object_checkpoint_renames
    .format(missing_names))
ValueError: Attempting to warm-start from an object-based checkpoint, but found that the checkpoint did not contain values for all variables. The following variables were missing: {'transformer/layer_2/output/bias', 'transformer/layer_4/self_attention/query/bias', 'transformer/layer_10/intermediate/kernel', 'transformer/layer_10/self_attention/query/kernel', 'transformer/layer_4/self_attention/value/kernel', 'transformer/layer_9/self_attention/attention_output/bias', 'transformer/layer_4/output_layer_norm/gamma', 'transformer/layer_1/self_attention/value/bias', 'transformer/layer_2/output_layer_norm/beta', 'transformer/layer_3/self_attention/key/kernel', 'transformer/layer_9/output_layer_norm/beta', 'transformer/layer_10/self_attention/key/bias', 'transformer/layer_9/output/bias', 'transformer/layer_10/self_attention/value/bias', 'transformer/layer_3/self_attention/query/kernel', 'transformer/layer_0/self_attention/value/bias', 'transformer/layer_5/self_attention/attention_output/kernel', 'transformer/layer_2/self_attention/attention_output/kernel', 'transformer/layer_7/self_attention/key/kernel', 'transformer/layer_11/output/kernel', 'transformer/layer_8/self_attention/attention_output/bias', 'transformer/layer_11/intermediate/kernel', 'transformer/layer_10/self_attention/value/kernel', 'transformer/layer_6/intermediate/kernel', 'transformer/layer_5/self_attention/query/bias', 'transformer/layer_9/self_attention/key/bias', 'transformer/layer_0/output_layer_norm/beta', 'embeddings/layer_norm/gamma', 'transformer/layer_2/self_attention/value/bias', 'transformer/layer_5/output/bias', 'transformer/layer_7/self_attention/attention_output/bias', 'transformer/layer_8/output/kernel', 'transformer/layer_4/self_attention/query/kernel', 'transformer/layer_11/self_attention/query/bias', 'pooler_transform/bias', 'transformer/layer_1/self_attention/value/kernel', 'transformer/layer_4/intermediate/bias', 'transformer/layer_8/self_attention/query/bias', 'transformer/layer_2/intermediate/kernel', 'transformer/layer_4/intermediate/kernel', 'transformer/layer_10/output_layer_norm/gamma', 'transformer/layer_8/self_attention/value/kernel', 'transformer/layer_10/self_attention/query/bias', 'transformer/layer_4/self_attention_layer_norm/gamma', 'transformer/layer_5/output_layer_norm/beta', 'transformer/layer_2/self_attention/key/bias', 'transformer/layer_1/self_attention/attention_output/bias', 'transformer/layer_6/self_attention/key/kernel', 'transformer/layer_0/output/bias', 'transformer/layer_1/self_attention/key/bias', 'transformer/layer_9/self_attention_layer_norm/beta', 'transformer/layer_6/output/kernel', 'transformer/layer_4/output_layer_norm/beta', 'transformer/layer_2/output_layer_norm/gamma', 'transformer/layer_5/intermediate/kernel', 'transformer/layer_8/self_attention/attention_output/kernel', 'transformer/layer_3/output/bias', 'transformer/layer_11/self_attention_layer_norm/beta', 'transformer/layer_3/self_attention/attention_output/bias', 'transformer/layer_7/self_attention/value/kernel', 'transformer/layer_8/output/bias', 'transformer/layer_3/output_layer_norm/gamma', 'transformer/layer_4/self_attention/key/bias', 'transformer/layer_6/self_attention_layer_norm/gamma', 'transformer/layer_6/self_attention/query/bias', 'transformer/layer_2/self_attention/query/kernel', 'transformer/layer_6/output_layer_norm/gamma', 'transformer/layer_0/self_attention/query/kernel', 'transformer/layer_5/self_attention_layer_norm/gamma', 'transformer/layer_5/intermediate/bias', 
'transformer/layer_0/self_attention/attention_output/bias', 'transformer/layer_2/output/kernel', 'transformer/layer_5/self_attention/value/kernel', 'transformer/layer_7/output_layer_norm/gamma', 'transformer/layer_8/self_attention/key/kernel', 'transformer/layer_9/intermediate/bias', 'transformer/layer_11/output/bias', 'transformer/layer_8/output_layer_norm/beta', 'transformer/layer_0/self_attention_layer_norm/beta', 'transformer/layer_10/self_attention/key/kernel', 'transformer/layer_3/self_attention_layer_norm/gamma', 'transformer/layer_2/self_attention_layer_norm/gamma', 'transformer/layer_0/intermediate/kernel', 'transformer/layer_8/intermediate/bias', 'transformer/layer_1/output/kernel', 'transformer/layer_9/output/kernel', 'transformer/layer_0/self_attention/value/kernel', 'transformer/layer_1/output_layer_norm/gamma', 'transformer/layer_6/self_attention/attention_output/kernel', 'transformer/layer_0/self_attention/key/bias', 'transformer/layer_4/output/bias', 'type_embeddings/embeddings', 'transformer/layer_1/self_attention_layer_norm/gamma', 'transformer/layer_8/self_attention/key/bias', 'transformer/layer_9/self_attention/key/kernel', 'transformer/layer_1/self_attention/attention_output/kernel', 'embeddings/layer_norm/beta', 'transformer/layer_3/self_attention/attention_output/kernel', 'transformer/layer_7/self_attention_layer_norm/gamma', 'transformer/layer_7/self_attention/attention_output/kernel', 'transformer/layer_1/intermediate/bias', 'transformer/layer_10/self_attention/attention_output/kernel', 'transformer/layer_9/self_attention/query/kernel', 'transformer/layer_11/intermediate/bias', 'transformer/layer_5/self_attention/value/bias', 'transformer/layer_11/self_attention/value/bias', 'transformer/layer_9/self_attention/attention_output/kernel', 'transformer/layer_0/self_attention/attention_output/kernel', 'transformer/layer_4/self_attention/attention_output/bias', 'transformer/layer_4/self_attention/key/kernel', 'transformer/layer_3/intermediate/kernel', 'transformer/layer_3/output_layer_norm/beta', 'transformer/layer_3/output/kernel', 'transformer/layer_9/intermediate/kernel', 'transformer/layer_7/output/bias', 'pooler_transform/kernel', 'transformer/layer_3/self_attention/value/kernel', 'transformer/layer_0/self_attention/query/bias', 'transformer/layer_11/self_attention/key/kernel', 'transformer/layer_10/self_attention_layer_norm/gamma', 'transformer/layer_5/output_layer_norm/gamma', 'transformer/layer_10/self_attention_layer_norm/beta', 'transformer/layer_6/self_attention_layer_norm/beta', 'transformer/layer_9/output_layer_norm/gamma', 'transformer/layer_0/output/kernel', 'transformer/layer_10/intermediate/bias', 'transformer/layer_5/self_attention/attention_output/bias', 'transformer/layer_6/self_attention/value/bias', 'transformer/layer_7/output_layer_norm/beta', 'transformer/layer_5/self_attention/key/kernel', 'transformer/layer_2/self_attention_layer_norm/beta', 'transformer/layer_3/self_attention/query/bias', 'transformer/layer_6/output/bias', 'transformer/layer_4/self_attention_layer_norm/beta', 'transformer/layer_7/self_attention/value/bias', 'transformer/layer_5/self_attention/query/kernel', 'transformer/layer_3/self_attention/value/bias', 'transformer/layer_6/self_attention/value/kernel', 'position_embedding/embeddings', 'transformer/layer_0/intermediate/bias', 'transformer/layer_7/self_attention/key/bias', 'transformer/layer_8/self_attention/value/bias', 'transformer/layer_10/output_layer_norm/beta', 'transformer/layer_6/intermediate/bias', 
'transformer/layer_5/self_attention/key/bias', 'transformer/layer_11/self_attention/key/bias', 'transformer/layer_1/self_attention/key/kernel', 'transformer/layer_8/self_attention_layer_norm/gamma', 'transformer/layer_1/intermediate/kernel', 'transformer/layer_5/output/kernel', 'transformer/layer_7/self_attention/query/bias', 'transformer/layer_11/output_layer_norm/beta', 'transformer/layer_1/output/bias', 'transformer/layer_6/self_attention/query/kernel', 'transformer/layer_6/self_attention/attention_output/bias', 'transformer/layer_9/self_attention/value/kernel', 'transformer/layer_7/self_attention_layer_norm/beta', 'transformer/layer_11/self_attention/attention_output/bias', 'transformer/layer_3/self_attention/key/bias', 'transformer/layer_1/self_attention/query/kernel', 'transformer/layer_1/self_attention/query/bias', 'transformer/layer_2/self_attention/query/bias', 'transformer/layer_9/self_attention/query/bias', 'transformer/layer_10/output/bias', 'transformer/layer_11/self_attention/value/kernel', 'transformer/layer_2/self_attention/value/kernel', 'transformer/layer_7/intermediate/kernel', 'transformer/layer_11/self_attention/query/kernel', 'transformer/layer_0/self_attention/key/kernel', 'transformer/layer_11/self_attention/attention_output/kernel', 'transformer/layer_8/self_attention/query/kernel', 'transformer/layer_1/output_layer_norm/beta', 'transformer/layer_6/output_layer_norm/beta', 'transformer/layer_9/self_attention_layer_norm/gamma', 'transformer/layer_8/output_layer_norm/gamma', 'transformer/layer_2/self_attention/key/kernel', 'transformer/layer_7/self_attention/query/kernel', 'transformer/layer_3/intermediate/bias', 'transformer/layer_4/self_attention/attention_output/kernel', 'transformer/layer_4/output/kernel', 'transformer/layer_5/self_attention_layer_norm/beta', 'transformer/layer_3/self_attention_layer_norm/beta', 'transformer/layer_7/output/kernel', 'transformer/layer_0/self_attention_layer_norm/gamma', 'transformer/layer_8/self_attention_layer_norm/beta', 'transformer/layer_8/intermediate/kernel', 'transformer/layer_10/output/kernel', 'transformer/layer_6/self_attention/key/bias', 'transformer/layer_9/self_attention/value/bias', 'transformer/layer_11/output_layer_norm/gamma', 'transformer/layer_2/self_attention/attention_output/bias', 'transformer/layer_1/self_attention_layer_norm/beta', 'transformer/layer_0/output_layer_norm/gamma', 'transformer/layer_7/intermediate/bias', 'word_embeddings/embeddings', 'transformer/layer_11/self_attention_layer_norm/gamma', 'transformer/layer_10/self_attention/attention_output/bias', 'transformer/layer_4/self_attention/value/bias', 'transformer/layer_2/intermediate/bias'}
I think those are the missing "standard" variables collected by ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start) in the file tensorflow/python/training/warm_starting_util.py, and I don't know how to get around this. Meanwhile, I have read your Learning-to-Rank with BERT in TF-Ranking paper, and it seems that you integrated RoBERTa and ELECTRA in those experiments. I wonder how that was done. If I may be so bold, would you consider developing a more generic TFR-BERT framework compatible with other BERT-like models?
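For reference, the variable names a checkpoint actually contains can be listed with tf.train.list_variables and compared against the transformer/layer_*/... names reported as missing above; the path below is only a placeholder for my converted checkpoint:

import tensorflow as tf

# Placeholder path -- point this at the checkpoint prefix passed to the example.
ckpt_path = "/path/to/converted/bert_model.ckpt"

# Each entry is (variable_name, shape). Compare these names with the
# "transformer/layer_*/..." names that the warm-start step expects.
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)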
Hi @WesleyHung,
Is your customized BERT checkpoint a TF2 checkpoint? Did you also provide the json file (similar to the toy example here) as the bert_config_file argument?
As for a more generic TFR-BERT framework compatible with other BERT-like models, yes, it is on our TODO list ;)
@HongleiZhuang Thank you for your reply! Yes, I used it as a TF2 checkpoint, and the model does come with a config json file. However, that customized BERT-like model doesn't provide a TF checkpoint (for reference: https://huggingface.co/toastynews/electra-hongkongese-base-discriminator), so I loaded the TF model with the Hugging Face library and then saved it as a checkpoint file.
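To be concrete, the conversion I describe is roughly the following sketch (the model class and output path are just what applies to my setup, not something from the TFR-BERT examples):

import tensorflow as tf
from transformers import TFElectraModel

# Load the TF weights published on the Hugging Face hub
# (add from_pt=True if only PyTorch weights are available).
model = TFElectraModel.from_pretrained(
    "toastynews/electra-hongkongese-base-discriminator")

# Save them as a TF2 object-based checkpoint (output path is illustrative).
ckpt = tf.train.Checkpoint(model=model)
ckpt.save("/tmp/electra_hk_base/bert_model.ckpt")

Note that a checkpoint saved this way carries the Keras object-based variable names, which generally do not match the transformer/layer_*/... names listed in the error above.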
Hi @WesleyHung,
Can you inspect the checkpoint you generated and the checkpoint downloaded from the TF Model Garden and see if they have the same variable names? If not, you may need to generate a mapping between them, use it to build your own warm-start settings, and pass those in like here. Another option would be to convert the PyTorch model to a TensorFlow 2 checkpoint.
Let me know if any of these options work for you.
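For the first option, a rough sketch of building such a mapping and turning it into a warm-start setting is shown below; the paths, the single mapping entry, and the way the setting is handed to the estimator are all illustrative, and the real mapping has to be derived from the two variable listings:

import tensorflow as tf

# Step 1: list the variables in both checkpoints (paths are illustrative).
expected = [name for name, _ in tf.train.list_variables(
    "/path/to/model_garden_bert/bert_model.ckpt")]
actual = [name for name, _ in tf.train.list_variables(
    "/path/to/converted_electra/bert_model.ckpt")]

# Step 2: map the names the TFR-BERT model expects to the names present in
# the converted checkpoint. The single entry below is purely illustrative.
name_map = {
    "transformer/layer_0/self_attention/query/kernel":
        "model/electra/encoder/layer_._0/attention/self/query/kernel",
    # ...one entry per variable reported as missing in the error message...
}

# Step 3: wrap the mapping in a WarmStartSettings; this object can then be
# given to the estimator (e.g. via its warm_start_from argument) in place of
# the default warm-start setting used by the pipeline.
warm_start = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from="/path/to/converted_electra/bert_model.ckpt",
    vars_to_warm_start=list(name_map),  # restrict warm start to mapped variables
    var_name_to_prev_var_name=name_map)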