
How can I save one model as a .pb file?

Open · moonsin opened this issue 6 years ago · 10 comments

I have trained a model for one task and have its latest .ckpt files. How can I save this model as a .pb file, and then load that .pb file to predict the result for a single sentence? I cannot find the tensor name for the input.
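For the prediction half of the question, a BERT classifier built by `run_classifier.py` expects batched int32 arrays laid out the way its `convert_single_example` builds them: `[CLS]` + WordPiece tokens + `[SEP]`, `segment_ids` of 0 for a single sentence, an `input_mask` of 1 over real tokens, and zero padding up to `max_seq_length`. Below is a minimal sketch of that layout, assuming the sentence is already split into WordPieces (in the repo that is the job of `tokenization.FullTokenizer`). `single_sentence_features` and the tiny `toy_vocab` are hypothetical, for illustration only.

```python
def single_sentence_features(wordpieces, vocab, max_seq_length=128):
    """Build a batch of one BERT example (lists of ints) for a single sentence.

    Mirrors the single-sentence path of run_classifier.convert_single_example:
    truncate, add [CLS]/[SEP], mark real tokens with 1 and zero-pad the rest.
    """
    tokens = ["[CLS]"] + list(wordpieces)[: max_seq_length - 2] + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(input_ids)   # 1 = real token, 0 = padding
    segment_ids = [0] * len(input_ids)  # single sentence -> segment A only
    padding = [0] * (max_seq_length - len(input_ids))
    return {
        "input_ids": [input_ids + padding],
        "input_mask": [input_mask + padding],
        "segment_ids": [segment_ids + padding],
    }


# Hypothetical vocabulary; a real model reads vocab.txt through FullTokenizer.
toy_vocab = {"[CLS]": 101, "[SEP]": 102, "hello": 7592, "world": 2088}
features = single_sentence_features(["hello", "world"], toy_vocab, max_seq_length=6)
print(features["input_ids"])   # [[101, 7592, 2088, 102, 0, 0]]
print(features["input_mask"])  # [[1, 1, 1, 1, 0, 0]]
```

Each list can be wrapped in `np.array(..., dtype=np.int32)` before it is fed to the graph.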

moonsin, Jan 04 '19

Look in graph.pbtxt to find the node names. In my case, I tried to save my BERT classifier as a .pb file with freeze_graph. The output node is "loss/Softmax", and the inputs are "IteratorGetNext" for input_ids, "IteratorGetNext:1" for input_mask, "IteratorGetNext:2" for label_ids, and "IteratorGetNext:3" for segment_ids:

```python
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model import tag_constants
from tensorflow.core.protobuf import saver_pb2

freeze_graph.freeze_graph(
    input_graph=MODEL_DIR + SEP + 'graph.pbtxt', input_saver='', input_binary=False,
    input_checkpoint=MODEL_DIR + SEP + 'model.ckpt-' + version, output_node_names='loss/Softmax',
    restore_op_name=None, filename_tensor_name=None, output_graph=PB_MODEL_FILE,
    clear_devices=True, initializer_nodes=None, variable_names_whitelist="",
    variable_names_blacklist="", input_meta_graph=None, input_saved_model_dir=None,
    saved_model_tags=tag_constants.SERVING, checkpoint_version=saver_pb2.SaverDef.V2)
```

However, it fails with an IndexError (list index out of range):

```
Traceback (most recent call last):
  File "fileconverter.py", line 29, in <module>
    convert(str(2504))
  File "fileconverter.py", line 24, in convert
    checkpoint_version=saver_pb2.SaverDef.V2)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 363, in freeze_graph
    checkpoint_version=checkpoint_version)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 190, in freeze_graph_with_def_protos
    var_list=var_list, write_version=checkpoint_version)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1102, in __init__
    self.build()
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1114, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1151, in _build
    build_save=build_save, build_restore=build_restore)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 773, in _build_internal
    saveables = self._ValidateAndSliceInputs(names_to_saveables)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 680, in _ValidateAndSliceInputs
    for converted_saveable_object in self.SaveableObjectsForOp(op, name):
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 654, in SaveableObjectsForOp
    variable, "", name)
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 128, in __init__
    self.handle_op = var.op.inputs[0]
  File "/data/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2128, in __getitem__
    return self._inputs[i]
IndexError: list index out of range
```

I tried Python 2.7, 3.5, and 3.6 with TensorFlow 1.11 and 1.12. Thanks in advance for any help.
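A minimal, dependency-free sketch of the first step in that comment, finding node names in graph.pbtxt. It assumes the file is a GraphDef in protobuf text format and reads only the `name` and `op` fields of top-level `node { ... }` blocks; `list_nodes` and `find_candidates` are hypothetical helpers, not TensorFlow APIs. With TensorFlow installed, parsing the file into a `tf.compat.v1.GraphDef` with `google.protobuf.text_format.Merge` is the more robust route.

```python
import re

# Matches a quoted protobuf string, including escaped characters inside it.
_QUOTED = re.compile(r'"(?:[^"\\]|\\.)*"')
_FIELD = re.compile(r'^(name|op): "(.*)"$')


def list_nodes(pbtxt_text):
    """Return (name, op) pairs for the top-level `node { ... }` blocks of a text-format GraphDef."""
    nodes, depth, current = [], 0, None
    for raw in pbtxt_text.splitlines():
        line = raw.strip()
        if depth == 0 and line == "node {":
            current = {}  # entering a new top-level node block
        elif depth == 1 and current is not None:
            m = _FIELD.match(line)
            if m:
                current[m.group(1)] = m.group(2)
        # Count braces outside quoted strings only, so string attrs cannot skew the depth.
        bare = _QUOTED.sub('""', line)
        depth += bare.count("{") - bare.count("}")
        if depth == 0 and current is not None:
            nodes.append((current.get("name"), current.get("op")))
            current = None
    return nodes


def find_candidates(pbtxt_path, op_types=("IteratorGetNext", "Placeholder", "Softmax")):
    """List nodes whose op type suggests a graph input or a classifier output."""
    with open(pbtxt_path) as f:
        return [(name, op) for name, op in list_nodes(f.read()) if op in op_types]
```

For the graph described above, `find_candidates("graph.pbtxt")` should surface `IteratorGetNext` as the input node and `loss/Softmax` as the output; the individual inputs are its outputs `:0` to `:3`.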

leerumor, Jan 04 '19

You may use bert-as-service to extract features using a fine-tuned model: https://github.com/hanxiao/bert-as-service/#serving-a-fine-tuned-bert-model

It also generates a single file for the frozen model.

hanxiao, Jan 09 '19

I ran into the same problem, and I rebuilt the model without tf.Example.

zxp93, Jan 15 '19

So how do I convert a .ckpt model to .pb format?

Gpwner, Feb 14 '19

Redefine the network, inspect the graph in TensorBoard, and find the name of the output node you need.
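A minimal sketch of that inspection step, assuming TensorFlow 1.13 or later (2.x works through `tf.compat.v1`): write the rebuilt graph to an events file that TensorBoard can display and, as a quicker in-process check, list ops whose type looks like a prediction output. The tiny stand-in graph, its node names, and both helper functions are hypothetical.

```python
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def write_graph_for_tensorboard(graph, logdir):
    """Write `graph` to an events file; view it with `tensorboard --logdir <logdir>`.

    tf.compat.v1.summary.FileWriter raises under eager execution, so call this
    inside `graph.as_default()` when using TensorFlow 2.x.
    """
    writer = tf1.summary.FileWriter(logdir, graph=graph)
    writer.close()


def output_candidates(graph, op_types=("Softmax", "Sigmoid", "ArgMax")):
    """Return names of ops whose type suggests a prediction output."""
    return [op.name for op in graph.get_operations() if op.type in op_types]


def demo():
    """Build a tiny stand-in graph, export it for TensorBoard, and list output candidates."""
    graph = tf.Graph()
    logdir = tempfile.mkdtemp()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, [None, 4], name="input_ids")
        with tf1.variable_scope("loss"):
            w = tf1.get_variable("output_weights", [4, 2])
            tf.nn.softmax(tf.matmul(x, w), name="Softmax")  # op name "loss/Softmax"
        write_graph_for_tensorboard(graph, logdir)
    return logdir, output_candidates(graph)
```

Point `tensorboard --logdir` at the returned directory and open the Graphs dashboard; `output_candidates` is only a shortcut and can miss outputs whose final op has a different type.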

JonyKai, Mar 13 '19

I redefined the network to use placeholders for the inputs and to train with sess.run(optimizer). With this change I can save and load the .pb file successfully: https://github.com/danielkaifeng/TF_BERT_Chinese_Article_Auto_Generation
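A minimal sketch of that placeholder approach end to end (checkpoint, freeze, load, predict), assuming TensorFlow 1.13 or later (2.x works through `tf.compat.v1`). To stay self-contained it swaps BERT for a one-layer toy classifier and a freshly initialised checkpoint. `build_inference_graph`, the placeholder names, `MAX_SEQ_LENGTH`, and the file names are illustrative assumptions; only the `loss/Softmax` output name comes from the thread. For a real export you would build `modeling.BertModel` on the placeholders and restore your fine-tuned checkpoint instead.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
MAX_SEQ_LENGTH = 8  # toy value; assumption, use your fine-tuning max_seq_length


def build_inference_graph(num_labels=2):
    """Placeholders replace the IteratorGetNext input; a toy layer stands in for BertModel."""
    shape = [None, MAX_SEQ_LENGTH]
    input_ids = tf1.placeholder(tf.int32, shape, name="input_ids")
    input_mask = tf1.placeholder(tf.int32, shape, name="input_mask")
    segment_ids = tf1.placeholder(tf.int32, shape, name="segment_ids")
    # Toy features; a real export would build modeling.BertModel on these placeholders.
    feats = tf.cast(input_ids * input_mask + segment_ids, tf.float32) / 1000.0
    with tf1.variable_scope("loss"):
        w = tf1.get_variable("output_weights", [MAX_SEQ_LENGTH, num_labels])
        return tf.nn.softmax(tf.matmul(feats, w), name="Softmax")  # op "loss/Softmax"


def save_checkpoint(prefix):
    """Stands in for training: initialise the variables and write a .ckpt."""
    with tf.Graph().as_default():
        build_inference_graph()
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            return saver.save(sess, prefix)


def freeze(ckpt_path, pb_path, output_node="loss/Softmax"):
    """Rebuild the graph, restore the checkpoint, fold variables into constants, write .pb."""
    with tf.Graph().as_default() as graph:
        build_inference_graph()
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            saver.restore(sess, ckpt_path)
            frozen = tf1.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), [output_node])
    with open(pb_path, "wb") as f:
        f.write(frozen.SerializeToString())


def predict(pb_path, features):
    """Load the frozen graph and feed the placeholders by tensor name."""
    graph_def = tf1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf1.import_graph_def(graph_def, name="")
        with tf1.Session(graph=graph) as sess:
            feed = {name + ":0": value for name, value in features.items()}
            return sess.run("loss/Softmax:0", feed_dict=feed)


def run_demo():
    workdir = tempfile.mkdtemp()
    ckpt = save_checkpoint(os.path.join(workdir, "model.ckpt"))
    pb_path = os.path.join(workdir, "frozen_model.pb")
    freeze(ckpt, pb_path)
    ids = np.array([[101, 7592, 102, 0, 0, 0, 0, 0]], dtype=np.int32)
    return predict(pb_path, {
        "input_ids": ids,
        "input_mask": (ids != 0).astype(np.int32),
        "segment_ids": np.zeros_like(ids),
    })
```

Because the placeholders are named explicitly, the input tensor names (`input_ids:0` and so on) are known up front, which avoids hunting for `IteratorGetNext` outputs and keeps label-related nodes out of the exported graph.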

danielkaifeng, Mar 28 '19

@danielkaifeng Hi, could you explain in more detail how to use your fork to save a .pb file from a pre-trained BERT model? Thanks.

cloudhuang, Apr 28 '19

@cloudhuang https://github.com/SunYanCN/BERT-chinese-text-classification-and-deployment

SunYanCN, Jul 12 '19

See this exporter script: https://github.com/yajian/bert/blob/master/model_exporter.py

WillLiGitHub, Feb 17 '20

@Gpwner Did you find a solution?

samida22, Dec 07 '21