
TensorRT rebuilds ops at run time

Open nrothGIT opened this issue 5 years ago • 34 comments

At conversion time, I call

converter.convert()
converter.build(input_fn=some_input_fn)
converter.save()

and see logging to the effect of:

tensorflow/core/grappler/optimizers/meta_optimizer.cc:814] Optimization results for grappler item: StatefulPartitionedCall/model_1/TRTEngineOp_1_native_segment
2020-03-18 01:30:00.932297: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   constant_folding: Graph size after: 33 nodes (0), 32 edges (0), time = 1.199ms.
2020-03-18 01:30:00.932302: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   layout: Graph size after: 33 nodes (0), 32 edges (0), time = 1.307ms.
2020-03-18 01:30:00.932307: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   constant_folding: Graph size after: 33 nodes (0), 32 edges (0), time = 1.228ms.
...
2020-03-18 01:30:37.946847: I tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:736] Building a new TensorRT engine for StatefulPartitionedCall/model_1/TRTEngineOp_1 with input shapes: [[25,64,4,4], [25,64,4,4]]

However, at run time, after loading the saved model above and running it a couple of times, I see that the model tries to build a new engine:

2020-03-18 01:31:58.716377: I tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:736] Building a new TensorRT engine for PartitionedCall/StatefulPartitionedCall/model_1/TRTEngineOp_1 with input shapes: [[25,64,4,4], [25,64,4,4]]

Note that the new engine's name is the previous engine's name nested under PartitionedCall. Any thoughts on why a new engine would be built despite the build() call, or on how to debug this? I have verified that the input sizes are the same, as the logs show.

EDIT: This appears to be a semi-known issue that is meant to be handled by name canonicalization here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc#L869. However, that logic is not applied correctly in this case: PartitionedCall/StatefulPartitionedCall/model_1/TRTEngineOp_1 and StatefulPartitionedCall/model_1/TRTEngineOp_1 are not matched to the same engine despite sharing the same suffix.
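To make the expected matching concrete, here is a minimal Python sketch, assuming that canonicalization means stripping leading function-call wrapper scopes (PartitionedCall, StatefulPartitionedCall) from the op name. The helper canonical_engine_name and the WRAPPER_SCOPES set are hypothetical; this is not the C++ logic in trt_engine_op.cc.

```python
# Hypothetical illustration only; this is NOT the TF-TRT C++ implementation.
WRAPPER_SCOPES = {"PartitionedCall", "StatefulPartitionedCall"}


def canonical_engine_name(op_name: str) -> str:
    """Strip leading function-call wrapper scopes from a TRTEngineOp name."""
    parts = op_name.split("/")
    # Drop wrapper scopes only while they appear at the front of the path.
    while parts and parts[0] in WRAPPER_SCOPES:
        parts.pop(0)
    return "/".join(parts)


built_at_conversion = "StatefulPartitionedCall/model_1/TRTEngineOp_1"
requested_at_run_time = "PartitionedCall/StatefulPartitionedCall/model_1/TRTEngineOp_1"

print(canonical_engine_name(built_at_conversion))    # model_1/TRTEngineOp_1
print(canonical_engine_name(requested_at_run_time))  # model_1/TRTEngineOp_1
```

Under this assumption both names map to model_1/TRTEngineOp_1, so the engine built by converter.build() would be found instead of rebuilt. Stripping only known wrapper scopes, rather than keeping just the last path component, avoids treating engines from different scopes as the same engine.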

nrothGIT, Mar 18 '20 01:03

I am having the same issue. I need low-latency inference for a stateful recurrent network, and the engine rebuild adds a lot of latency. Throughput also drops after the rebuild: my application runs at ~40 FPS, and after the engine is rebuilt it drops to ~30 FPS.

I tried updating to tf-2.2-rc1, and it still has the same issue.

cbodenst, Mar 26 '20 12:03

@nrothGIT did you find any workaround using the information you posted?

cbodenst, Apr 03 '20 07:04

CC @bixia1

sanjoy, Apr 06 '20 04:04

@cbodenst, sadly I haven't had time to dig much further into this, so I have no good fix at the moment. I'll post back if I make any progress.

nrothGIT, Apr 06 '20 18:04

Thanks for the report. It should not try to rebuild the engines, but it is difficult to say more at this stage. It might help if you could post the full logs at verbosity level 2:

TF_CPP_VMODULE=trt_engine_op=2,convert_nodes=2,convert_graph=2,segment=2,trt_shape_optimization_profiles=2,trt_engine_resource_ops=2 python my_script.py 
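If setting the variable on the command line is inconvenient, a minimal Python sketch is to build the same TF_CPP_VMODULE value and store it in os.environ at the very top of the script. To be safe it should run before TensorFlow is imported, since the C++ logging may read the variable as soon as the library loads. enable_trt_vlog is a hypothetical helper, not a TensorFlow API.

```python
import os

# Module names copied from the command line above.
TRT_VLOG_MODULES = [
    "trt_engine_op",
    "convert_nodes",
    "convert_graph",
    "segment",
    "trt_shape_optimization_profiles",
    "trt_engine_resource_ops",
]


def enable_trt_vlog(level: int = 2) -> str:
    """Set TF_CPP_VMODULE for the TF-TRT modules and return the value used.

    Call this before `import tensorflow` so the setting is visible when
    TensorFlow's C++ logging starts up.
    """
    value = ",".join(f"{module}={level}" for module in TRT_VLOG_MODULES)
    os.environ["TF_CPP_VMODULE"] = value
    return value
```

Usage: call enable_trt_vlog() as the first statement of my_script.py, then import tensorflow as usual.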

It would be even better if you could share a reproducer, so that we can take a closer look.

tfeher, Apr 06 '20 18:04

@tfeher I will look into creating a reproducer, but I am not sure how to build one, since TensorRT optimizations are strongly hardware-dependent.

I attached a log file. You can see a rebuild at the beginning of the script, and after another 2 minutes it builds the engine again. You can grep for "Building a new TensorRT engine" to find them.
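Instead of grepping by hand, a small script can list every engine build and flag the ones that happen more than once for the same op name and input shapes. This is a sketch that assumes the log line format quoted in this thread ("<date> <time>: I <file>:<line>] Building a new TensorRT engine for <op> with input shapes: <shapes>"); find_engine_builds and find_repeated_builds are hypothetical helpers, not part of TensorFlow. Builds under a different name prefix (such as PartitionedCall/...) count as different ops here.

```python
import re
from collections import defaultdict

# Assumed format, taken from the log excerpts in this thread.
BUILD_RE = re.compile(
    r"^(?P<ts>\S+ \S+): I \S+\] Building a new TensorRT engine for "
    r"(?P<op>\S+) with input shapes: (?P<shapes>.*)$"
)


def find_engine_builds(lines):
    """Map (op name, input shapes) -> timestamps at which an engine was built."""
    builds = defaultdict(list)
    for line in lines:
        match = BUILD_RE.match(line.strip())
        if match:
            builds[(match["op"], match["shapes"])].append(match["ts"])
    return dict(builds)


def find_repeated_builds(lines):
    """Keep only (op, shapes) pairs that were built more than once, i.e. rebuilds."""
    return {key: ts for key, ts in find_engine_builds(lines).items() if len(ts) > 1}
```

Usage: with open("log.txt") as f: print(find_repeated_builds(f)). Any entry in the result is an engine that was built again for the same name and shapes.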

I also figured out why the model is slower after the rebuild: tf-2.2.0rc2 fails to build the engine for one segment and falls back to the native TensorFlow segment, whereas with tf-2.1.0 the whole model was converted properly. For the run in these logs, I converted the model directly with tf-2.2.0rc2.

My specs: CUDA 10.1, cuDNN 7.6, TensorRT 6.5, TensorFlow 2.2.0rc2

log.zip

cbodenst, Apr 08 '20 12:04

Hello, I'm facing a similar issue while trying to run multiple TensorRT-optimized models in TF2.

import os

import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

TENSORRT_WEIGHTS_ROOT = '/path/to/weights/root/'

class TensorRTModel:
    def __init__(self, model_name, weights_root=TENSORRT_WEIGHTS_ROOT):
        self.weights_root = weights_root
        self.model_name = model_name
        self.model, self.infer = self.__load_tensor_rt_model()

    def __load_tensor_rt_model(self):
        # Load the TF-TRT converted SavedModel and take its default serving signature.
        saved_model_loaded = tf.saved_model.load(os.path.join(
                self.weights_root,
                self.model_name,
                f'{self.model_name}_Conversions.TRT_FP32_MODEL'),
            tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']
        return saved_model_loaded, infer

class Model1(TensorRTModel):
    pass  # some model-specific code

class Model2(TensorRTModel):
    pass  # some model-specific code

if __name__ == '__main__':
    model1 = Model1('model1')
    model2 = Model2('model2')

    # Alternate between the two models; each call runs that model's TRTEngineOps.
    for i in range(1000):
        print(i)
        model1_input = tf.constant(np.random.normal(size=(8, 640, 640, 3)).astype(np.float32))
        result = model1.infer(model1_input)
        model2_input = tf.constant(np.random.normal(size=(8, 224, 224, 3)).astype(np.float32))
        result = model2.infer(model2_input)

After TensorRT optimization, Model1 contains 2 engine ops:
PartitionedCall/TRTEngineOp_1, input shapes: [[8,640,640,3]]
PartitionedCall/TRTEngineOp_0, input shapes: [[8,1536,18,18]]

Model2 contains 1 engine op:
PartitionedCall/TRTEngineOp_0, input shapes: [[8,224,224,3]]

At run time, the engine for PartitionedCall/TRTEngineOp_0 is rebuilt on every iteration, while the engine for PartitionedCall/TRTEngineOp_1 is built only once.

Logs attached.

Environment: nvcr.io/nvidia/tensorflow:19.12-tf2-py3 docker container

log.txt

Is it possible to manually set Op name/prefix so those two Ops would not be considered the same entity?
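The question above suggests both ops are treated as one entity because they share the name PartitionedCall/TRTEngineOp_0. A toy model of that hypothesis, assuming an engine cache keyed only by op name that holds one engine per name, reproduces the pattern described above. The class NameKeyedEngineCache is hypothetical and is not how TF-TRT is implemented.

```python
from collections import Counter


class NameKeyedEngineCache:
    """Toy cache: one engine per op name, rebuilt whenever the input shapes change."""

    def __init__(self):
        self._shapes_by_name = {}  # op name -> input shapes of the cached engine
        self.builds = []           # every (op name, shapes) pair that triggered a build

    def run(self, op_name, input_shapes):
        if self._shapes_by_name.get(op_name) != input_shapes:
            # No engine yet, or the cached one was built for other shapes:
            # "build" a new engine, replacing the previous entry for this name.
            self._shapes_by_name[op_name] = input_shapes
            self.builds.append((op_name, input_shapes))


cache = NameKeyedEngineCache()
for _ in range(3):
    # Model1's two engines.
    cache.run("PartitionedCall/TRTEngineOp_1", [[8, 640, 640, 3]])
    cache.run("PartitionedCall/TRTEngineOp_0", [[8, 1536, 18, 18]])
    # Model2's only engine, which shares a name with one of Model1's.
    cache.run("PartitionedCall/TRTEngineOp_0", [[8, 224, 224, 3]])

print(Counter(name for name, _ in cache.builds))
```

Over three iterations this records one build for TRTEngineOp_1 and six for TRTEngineOp_0. In this toy model, giving each model's ops a distinct name prefix removes the collision.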

andriiyurkiv, Apr 08 '20 12:04

We merged a change for this on April 3. Can you update your tree to check that again?

bixia1, Apr 08 '20 15:04

Can you share a link to the corresponding PR?

andriiyurkiv, Apr 09 '20 09:04

https://github.com/tensorflow/tensorflow/blob/c3e12bcd6d01f3bc92cbb99f3c91fc0bec3a8562/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc#L674-L675

https://github.com/tensorflow/tensorflow/commit/21de77485f9230eff364a4f0c8e0a55857965223

bixia1, Apr 09 '20 15:04

@bixia1 @tfeher I ran the code snippet I shared above using the latest version of the TF master branch. As the logs show, it no longer prints any TensorRT engine-building information, but the time per iteration remains the same as when the engines were rebuilt on every iteration.

Those two models were optimized using the new version of TF as well.

log_latest_master.txt
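Since the engine-building messages disappeared but the time per iteration did not improve, timing each call can show whether slow iterations remain. A minimal sketch: time_calls and find_latency_spikes are hypothetical helpers, the sketch assumes the timed callable blocks until its results are ready (for TensorFlow outputs, for example by calling .numpy() on them), and the 5x-median threshold is an arbitrary choice.

```python
import statistics
import time


def time_calls(fn, n):
    """Call fn() n times and return each call's wall-clock duration in seconds."""
    durations = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def find_latency_spikes(durations, factor=5.0):
    """Return indices of calls that took more than `factor` times the median duration."""
    median = statistics.median(durations)
    return [i for i, d in enumerate(durations) if d > factor * median]
```

For example, time_calls(lambda: [t.numpy() for t in model1.infer(model1_input).values()], 1000) times one model from the snippet above. If every iteration is uniformly slow, no spikes will show up, and comparing the median against a run of the unconverted model is the more useful check.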

andriiyurkiv, Apr 13 '20 09:04

I tested again with TensorFlow 2.2.0rc3, and it still rebuilds. The program runs fine for 10 minutes, but then it suddenly rebuilds the engine.

cbodenst, Apr 29 '20 18:04

> https://github.com/tensorflow/tensorflow/blob/c3e12bcd6d01f3bc92cbb99f3c91fc0bec3a8562/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc#L674-L675
> tensorflow/tensorflow@21de774

I also built the current master (f98da6aa76ab991fb06f2de24ee7a8b87348d66e). It still rebuilds the engines.

cbodenst, Apr 30 '20 17:04

Is there any progress on this issue?

cbodenst, May 29 '20 09:05

I will look into this, it is the next thing on my todo list.

bixia1, May 29 '20 17:05

@andriiyurkiv, can you post a simplified version of your test case?

bixia1, Jun 01 '20 17:06

Can one of you post a reproducer for this?

bixia1, Jun 03 '20 15:06

@bixia1 I will work on that

cbodenst, Jun 03 '20 15:06

Please try this code. I'm not sure whether this is hardware-dependent, but it builds the engine twice: once in converter.build() and once when the loop starts. However, I could not get it to rebuild again after many predictions, as happens in my production code.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
from tensorflow.keras.models import Model
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# One fixed input; the same shape is used for building and for inference.
data = tf.constant(np.empty([1, 200, 200, 3], np.float32))

def gen():
    # input_fn for converter.build(): yields one list of input tensors.
    yield [data]

def create_model():
    i = Input((200, 200, 3), batch_size=1)
    x = Conv2D(16, (3, 3))(i)
    x = Flatten()(x)
    x = Dense(10, activation="relu")(x)
    x = Dense(10, activation="relu")(x)
    m = Model(i, x)
    m.compile("sgd", "mse")
    model_path = "my_model"
    m.save(model_path, save_format="tf")  # save_format takes the string "tf"
    print("model created")
    return model_path

def convert_to_trt(input_path: str, output_path: str, data_gen=None, precision="FP32"):
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
    conversion_params = conversion_params._replace(precision_mode=precision)
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_path, conversion_params=conversion_params
    )
    converter.convert()
    print("build engine ...")
    if data_gen:
        # Build the engines ahead of time for the shapes produced by data_gen.
        converter.build(data_gen)
    converter.save(output_path)
    print("model converted")

if __name__ == "__main__":
    input_path = create_model()
    output_path = "trt_model"
    convert_to_trt(input_path, output_path, gen)
    # Load the converted SavedModel with the core SavedModel API and take its serving signature.
    model = tf.saved_model.load(output_path).signatures["serving_default"]
    for i in range(100000):
        print("round:", i)
        model(data)

Here is part of the output:

build engine ...
2020-06-04 16:16:39.485947: I tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:982] Building a new TensorRT engine for TRTEngineOp_0 with input shapes: [[1,200,200,3]]
2020-06-04 16:16:39.486004: I tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:1153] Linked TensorRT version: 6.0.1
2020-06-04 16:16:39.486091: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer.so.6
2020-06-04 16:16:39.486109: I tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:1154] Loaded TensorRT version: 6.0.1
2020-06-04 16:16:39.487998: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer_plugin.so.6
model converted
round: 0
2020-06-04 16:16:43.393328: I tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:982] Building a new TensorRT engine for PartitionedCall/TRTEngineOp_0 with input shapes: [[1,200,200,3]]

cbodenst, Jun 04 '20 16:06

I don't see the problem with the test case provided in comment #11.

bixia1, Jun 16 '20 00:06

@bixia1 The problem is that the TRT engine gets built twice because of the different name prefix (TRTEngineOp_0 during converter.build() versus PartitionedCall/TRTEngineOp_0 at run time). The engine can also be rebuilt again after many predictions, but that is not easy to reproduce. In general, this rules out high-throughput use cases.

Also, please take a look at the log file I posted earlier (log.zip, Apr 8), which shows a rebuild at the beginning of the script and another one about 2 minutes later.

cbodenst, Jun 16 '20 09:06

@cbodenst, can you point out which one is a rebuild? I processed your log with the command grep -e "Building a" -e "Finished" log and can't tell which build is a rebuild. In the file I attached, the builds seem to have different shapes. processed.txt

bixia1, Jun 16 '20 18:06

In your processed.txt you can see the rebuild at line 4 (timestamp "2020-04-08 11:40:48.633176"), 12 minutes after the processing loop started.

cbodenst, Jun 16 '20 22:06

I see that the engine was not created successfully:

2020-04-08 11:28:58.582814: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1002] Engine creation for StatefulPartitionedCall/TRTEngineOp_2 failed. The native segment will be used instead. Reason: Unimplemented: Transpose too large:10437228

bixia1, Jun 16 '20 22:06

@cbodenst, can you try setting this environment variable to see if it helps?

TF_DEBUG_TRT_ALLOW_INEFFICIENT_TRANSPOSE=1

bixia1, Jun 17 '20 18:06

I already tried that. It lets the TRT engine build properly, but the engine is still rebuilt after some time.

cbodenst, Jun 19 '20 09:06

You mean that even though the engine is built successfully, it is rebuilt after a while? Can you provide a log that shows this?

bixia1, Jun 19 '20 17:06

I made another run with TF_DEBUG_TRT_ALLOW_INEFFICIENT_TRANSPOSE=1. You can see that the engine is rebuilt at timestamp 2020-06-22 15:24:16.430878. As the logs also show, the rebuild sometimes fails. log.zip

cbodenst, Jun 22 '20 16:06

@bixia1, are you still working on this? I would really like to use TF-TRT, but if this doesn't work I need to invest time in alternatives. Didn't @nrothGIT already describe the issue in his very first post, or is that unrelated?

cbodenst, Jul 08 '20 09:07

@cbodenst Yes, I will look at this again soon.

bixia1, Jul 08 '20 15:07