
Code pointer for converting the original TF checkpoint to PyTorch

velocityCavalry opened this issue 3 years ago

Hello,

Thanks for creating this repository! I think it is useful! I was wondering whether you could provide the code pointer(s) for converting the original TensorFlow checkpoint to PyTorch, i.e., how do you convert gs://realm-data/cc_news_pretrained/bert/ and gs://realm-data/cc_news_pretrained/embedder to qqaatw/realm-cc-news-pretrained-bert and qqaatw/realm-cc-news-pretrained-embedder?

Thank you so much!

velocityCavalry avatar Nov 24 '21 20:11 velocityCavalry

Hi,

Thanks for your interest in this repository!

Currently there is no simple conversion script that converts pretrained TF checkpoints such as cc_news_pretrained into the compact RealmEmbedder and RealmKnowledgeAugEncoder models. You can do the conversion with the following code:

import logging

from transformers import RealmConfig, RealmKnowledgeAugEncoder, RealmRetriever
from transformers.models.realm.modeling_realm import logger

# Log which TF weights are (or are not) matched during conversion.
logger.setLevel(logging.INFO)

config = RealmConfig()

# Load the TF embedder checkpoint and re-save it as a PyTorch retriever.
retriever = RealmRetriever.from_pretrained("./data/cc_news_pretrained/embedder/variables/variables", config=config, from_tf=True)
retriever.save_pretrained("path/to/cc_news_retriever/")

# Load the TF BERT checkpoint and re-save it as a PyTorch encoder.
encoder = RealmKnowledgeAugEncoder.from_pretrained("./data/cc_news_pretrained/bert/variables/variables", config=config, from_tf=True)
encoder.save_pretrained("path/to/cc_news_encoder/")
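As a side note, the from_tf loading path works by remapping TF variable names to PyTorch parameter names. A minimal, self-contained sketch of that idea (the names below are illustrative of BERT-style checkpoints, not the exact REALM variable names, and real converters also handle cases such as transposing dense kernels):

```python
import re

def tf_to_pt_name(tf_name: str) -> str:
    """Map a TF-style variable name to a PyTorch-style parameter name.

    Illustrative only: real conversion code covers many more cases.
    """
    # TF scopes are slash-separated; PyTorch modules are dot-separated.
    name = tf_name.replace("/", ".")
    # TF names repeated layers like layer_0; nn.ModuleList uses layer.0.
    name = re.sub(r"layer_(\d+)", r"layer.\1", name)
    # TF calls a dense layer's weight matrix "kernel"; PyTorch uses "weight".
    name = name.replace("kernel", "weight")
    return name

print(tf_to_pt_name("bert/encoder/layer_0/attention/self/query/kernel"))
# bert.encoder.layer.0.attention.self.query.weight
```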

For fine-tuned TF checkpoints such as orqa_nq_model_from_realm, you can use the following script to convert. (This script can also convert pretrained TF checkpoints to PyTorch, but the saved models will then include block_emb, a pre-computed evidence embedding matrix, plus some newly initialized RealmReader weights, resulting in very large model sizes.)

python predictor.py  \
--retriever_path ./data/orqa_nq_model_from_realm/export/best_default/checkpoint/model.ckpt-300000  \
--checkpoint_path ./data/orqa_nq_model_from_realm/export/best_default/checkpoint/model.ckpt-300000  \
--saved_path ./path/to/ --question ""
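To give a sense of why block_emb dominates the saved model size, here is a back-of-the-envelope estimate. The figures are assumptions based on the usual REALM setup (roughly 13M evidence blocks, 128-dimensional projected embeddings, float32), not values read from this particular checkpoint:

```python
# Rough size of a pre-computed evidence embedding matrix (block_emb).
# All three figures below are assumptions, not read from the checkpoint.
num_blocks = 13_353_718   # approximate number of evidence blocks
embed_dim = 128           # projected embedding dimension
bytes_per_float = 4       # float32

size_gib = num_blocks * embed_dim * bytes_per_float / 2**30
print(f"block_emb alone: ~{size_gib:.1f} GiB")
```

Under these assumptions the matrix alone is on the order of several GiB, dwarfing the BERT-sized weights around it.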

FYI: There is a PR currently under review, so some of the aforementioned code and scripts might change in the future.

qqaatw avatar Nov 25 '21 07:11 qqaatw

Hi,

Thanks for the quick reply, and I think this is super useful :)! Hopefully the PR can get merged soon! Is there a way to test the correctness of the converted checkpoints?

Out of curiosity, what's the difference between the released embedder (https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) and retriever (https://huggingface.co/qqaatw/realm-cc-news-pretrained-retriever)?

Thanks again!

velocityCavalry avatar Nov 27 '21 17:11 velocityCavalry

Hello,

If you want to test your own question:

python predictor.py --question "Your question" \
--from_pt_finetuned \
--retriever_pretrained_name /path/to/converted_finetuned_pt_searcher \
--checkpoint_pretrained_name /path/to/converted_finetuned_pt_reader

If you want to run benchmarks, please refer to the "Benchmark" section in the README. For now there is no benchmark loaded from converted PyTorch checkpoints; all benchmarks are loaded from TF checkpoints and converted into PyTorch models on the fly.

RealmRetriever is a container model that includes two RealmEmbedders, one for the query sequence and one for the candidate sequence, and by default the weights of the two embedders are tied. Therefore, you can load an embedder's TF weights into either an embedder or a retriever (the loading function adapts the different namespaces internally). On the other hand, loading an embedder's PyTorch weights into a retriever wouldn't work, because the namespaces of their weights are different; that's why they were released separately.
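The namespace point can be illustrated with plain state-dict keys. This is a hypothetical sketch (the key prefixes are invented for illustration, not the actual REALM parameter names): an embedder checkpoint stores keys under a single prefix, while a retriever expects the same tensors under two embedder prefixes, so loading requires remapping, with tying expressed by reusing the same tensor object:

```python
def embedder_to_retriever(embedder_state: dict) -> dict:
    """Remap an embedder's flat state dict into a retriever-style state dict.

    Hypothetical prefixes: the retriever holds a query embedder and a
    candidate embedder whose weights are tied, so each tensor object is
    reused under both prefixes rather than copied.
    """
    retriever_state = {}
    for key, tensor in embedder_state.items():
        retriever_state[f"query_embedder.{key}"] = tensor  # tied weights:
        retriever_state[f"cand_embedder.{key}"] = tensor   # same object
    return retriever_state

embedder = {"bert.embeddings.word_embeddings.weight": [0.1, 0.2]}
retriever = embedder_to_retriever(embedder)
# Both prefixed keys refer to the very same object (weight tying).
assert (retriever["query_embedder.bert.embeddings.word_embeddings.weight"]
        is retriever["cand_embedder.bert.embeddings.word_embeddings.weight"])
```

Going the other way, a retriever-style checkpoint has prefixed keys that an embedder's parameter names simply don't match, which is why the PyTorch checkpoints had to be published separately.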

qqaatw avatar Nov 29 '21 12:11 qqaatw