
Converting TAPAS models to SavedModel/ONNX/TensorRT for faster inference

Open: hitz02 opened this issue 4 years ago • 5 comments

Hello Team,

Thanks for open sourcing this work!

While running the notebooks shared along with the models, I noticed that prediction takes almost a minute to return results on custom data (for the larger models).

After a little research into improving inference time for TF models, I found the following approaches:

  1. TF checkpoint -> SavedModel -> GraphDef -> optimized GraphDef -> SavedModel
  2. TF checkpoint -> SavedModel -> ONNX (applies some optimizations, such as freezing the graph and quantization)
  3. TF checkpoint -> SavedModel -> TensorRT inference graph
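To tell whether any of these conversions actually helps, it is useful to time the same prediction call before and after converting. Below is a minimal, framework-agnostic sketch; predict and the sample input are hypothetical placeholders for whatever callable runs the model, not a TAPAS API.

```python
import statistics
import time


def measure_latency(predict, example, warmup=1, runs=5):
    """Return the median wall-clock seconds of predict(example).

    The first `warmup` calls are not timed, since a first call often pays
    one-off costs (for example graph construction or weight loading).
    """
    for _ in range(warmup):
        predict(example)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        predict(example)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Hypothetical usage: compare the original checkpoint with a converted model.
# before = measure_latency(original_predict, sample_query)
# after = measure_latency(converted_predict, sample_query)
```

The median is used instead of the mean so that a single slow outlier run does not dominate the comparison.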

As a first step, I tried converting the TAPAS model checkpoint to the SavedModel format by referring to this link.

However, running the code failed with the following error:

  InvalidArgumentError: No OpKernel was registered to support Op 'TPUReplicatedInput' used by {{node input0}} with these attrs: [is_mirrored_variable=false, index=0, T=DT_INT32, N=32]
  Registered devices: [CPU, GPU, XLA_CPU, XLA_GPU]
  Registered kernels:

  [[input0]]
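The error suggests that the graph being exported still contains TPU-only ops (here TPUReplicatedInput), which have no CPU or GPU kernels. One way to confirm this is to scan the GraphDef for such nodes. A minimal sketch, assuming the heuristic that TPU-specific op types start with "TPU" (an assumption, not an exhaustive rule); find_tpu_ops is a hypothetical helper that only relies on each node exposing name and op fields, as GraphDef NodeDef entries do:

```python
def find_tpu_ops(graph_def):
    """List (node name, op type) pairs for nodes whose op type looks TPU-specific.

    Heuristic: any op type starting with "TPU" (for example TPUReplicatedInput)
    is treated as TPU-only. Works on a GraphDef or on any object whose .node
    items expose .name and .op.
    """
    return [(node.name, node.op) for node in graph_def.node
            if node.op.startswith("TPU")]


if __name__ == "__main__":
    from types import SimpleNamespace

    # Stand-in for a GraphDef so the sketch runs without TensorFlow.
    fake_graph = SimpleNamespace(node=[
        SimpleNamespace(name="input0", op="TPUReplicatedInput"),
        SimpleNamespace(name="bert/embeddings/add", op="AddV2"),
    ])
    print(find_tpu_ops(fake_graph))  # [('input0', 'TPUReplicatedInput')]
```

With TensorFlow available, the GraphDef could come from tf.compat.v1.get_default_graph().as_graph_def() after importing the checkpoint's meta graph (assuming the checkpoint ships a .meta file). If TPU ops show up, rebuilding the model graph in CPU mode and restoring the weights into it, rather than reusing the TPU graph, is one possible way around the error.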

Please let me know whether this is the correct way to convert the TAPAS models. If not, it would be great if someone could point me in the right direction.

I would also like to know whether the team plans to improve inference time by optimizing the TAPAS models.

Thanks!

hitz02, Dec 03 '20 07:12

Does this help:

https://github.com/google-research/bert/issues/882

In particular, the response by @RyanHuangNLP:

@saberkun thank you for response, I find the solution in other issue with your answer, it is just tf.session with the tpu config, thank you very much

Maybe the config needs to look like this:

  import tensorflow.compat.v1 as tf

  # Per-host input pipeline (PER_HOST_V2) for the TPU RunConfig.
  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.estimator.tpu.RunConfig(
      tpu_config=tf.estimator.tpu.TPUConfig(
          per_host_input_for_training=is_per_host))

Can you give it a try?

ghost, Dec 03 '20 13:12

Thanks for the quick response, @thomasmueller-google!

Can you help me understand where this piece of code should be plugged in?

hitz02, Dec 03 '20 14:12

In the other issue, it sounded like you should pass the RunConfig or TPUConfig to the session. However, judging from the docs, neither has the right type: the config argument of tf.Session expects a ConfigProto.

Alternatively, could you use TPUEstimator.export_saved_model? There is some documentation here.

ghost, Dec 03 '20 16:12
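A rough sketch of the export_saved_model route suggested above. export_saved_model takes an export directory and a serving_input_receiver_fn; the helper below builds such a function with one placeholder per feature. make_serving_input_fn is a hypothetical helper, not part of TAPAS, and the feature names, dtypes, and the single [batch, max_seq_length] shape are assumptions that would have to match what the TAPAS model_fn actually reads.

```python
def make_serving_input_fn(feature_dtypes, max_seq_length):
    """Return a serving_input_receiver_fn for Estimator.export_saved_model.

    feature_dtypes maps a feature name to a dtype name, e.g. {"input_ids": "int32"}.
    Every feature gets a [batch, max_seq_length] placeholder (an assumption).
    """

    def serving_input_fn():
        # Imported here so defining the helper does not require TensorFlow.
        import tensorflow.compat.v1 as tf

        placeholders = {
            name: tf.placeholder(tf.as_dtype(dtype), [None, max_seq_length], name=name)
            for name, dtype in feature_dtypes.items()
        }
        # The raw placeholders serve as both the model features and the receiver inputs.
        return tf.estimator.export.ServingInputReceiver(placeholders, placeholders)

    return serving_input_fn


# Hypothetical usage with an already-configured estimator:
# serving_fn = make_serving_input_fn({"input_ids": "int32", "input_mask": "int32"}, 512)
# estimator.export_saved_model("exported_model", serving_fn)
```

If the estimator was built for TPU, the exported graph may still include TPU ops; TPUEstimator's use_tpu and export_to_tpu constructor arguments are the knobs worth checking for a CPU/GPU-only export.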

I think this has a solution that can be adapted for Tapas.

ghost, Dec 03 '20 16:12

@thomasmueller-google thanks for helping out. I tried the method from google-research/bert#882, which shrinks the model by stripping the Adam optimizer variables from the checkpoint:

  import tensorflow.compat.v1 as tf

  checkpoint_path = "tapas_sqa_large/model.ckpt"
  new_checkpoint_path = "quantized/model1.ckpt"

  reader = tf.train.NewCheckpointReader(checkpoint_path)
  name_shape_map = reader.get_variable_to_shape_map()
  new_variable_map = {}
  for var_name in name_shape_map:
      if 'adam_v' not in var_name and 'adam_m' not in var_name:
          tensor = reader.get_tensor(var_name)
          var = tf.Variable(tensor, name=var_name)
          new_variable_map[var_name] = var

  saver = tf.train.Saver(new_variable_map, defer_build=False)
  with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      saver.save(sess, new_checkpoint_path)

When trying to convert the model, I get the following error:

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in save(self, sess, save_path, global_step, latest_filename, meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs, save_debug_info)
     1139     if not self._is_built and not context.executing_eagerly():
     1140       raise RuntimeError(
  -> 1141           "build() should be called before save if defer_build==True")
     1142     if latest_filename is None:
     1143       latest_filename = "checkpoint"

  RuntimeError: build() should be called before save if defer_build==True

gowthamvenkatsairam, Dec 03 '20 19:12
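A likely cause of this RuntimeError, assuming TensorFlow 2.x where eager execution is on by default: the variables and the Saver in the script above are created eagerly, so the Saver skips build(); the later with tf.Session() block then switches to graph mode, where save() finds the saver unbuilt. Below is a minimal sketch of the same stripping step with everything created inside an explicit tf.Graph, so eager execution is off while the Saver is built (calling tf.compat.v1.disable_eager_execution() at program start is an alternative). strip_optimizer_slots and is_optimizer_slot are hypothetical helper names; the sketch removes optimizer state only and does not quantize weights.

```python
import os

OPTIMIZER_SLOT_MARKERS = ("adam_m", "adam_v")


def is_optimizer_slot(var_name):
    """True for Adam moment variables such as '.../kernel/adam_m'."""
    return any(marker in var_name for marker in OPTIMIZER_SLOT_MARKERS)


def strip_optimizer_slots(checkpoint_path, new_checkpoint_path):
    """Copy all non-optimizer variables into a new checkpoint; return their names."""
    # Imported lazily so is_optimizer_slot can be used without TensorFlow.
    import tensorflow.compat.v1 as tf

    reader = tf.train.NewCheckpointReader(checkpoint_path)
    kept = sorted(name for name in reader.get_variable_to_shape_map()
                  if not is_optimizer_slot(name))

    # Saver.save refuses to write into a missing directory (local paths only here).
    parent = os.path.dirname(new_checkpoint_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Inside an explicit graph, eager execution is off, so Saver(...) builds immediately.
    with tf.Graph().as_default():
        variables = {name: tf.Variable(reader.get_tensor(name), name=name)
                     for name in kept}
        saver = tf.train.Saver(variables)
        # The session uses the graph that is the default when it is created.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, new_checkpoint_path)
    return kept


# Usage, with the paths from the script above:
# strip_optimizer_slots("tapas_sqa_large/model.ckpt", "quantized/model1.ckpt")
```

Since Adam keeps two moment variables per weight, dropping them should shrink the checkpoint substantially, but the weights keep their original float32 precision.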