torch2trt
Python-converted TRT engine used in C++
I have converted the model from PyTorch to TensorRT (and saved it with torch.save() as a .pth file). Can I use it in C++? How?
Hi Wallace00V,
Thanks for reaching out! Yes, you can, by serializing the TensorRT engine and executing it with the TensorRT C++ runtime.
Assume you optimized your model by calling

model_trt = torch2trt(model, [data])

The model_trt is a TRTModule, which contains an attribute engine that is the TensorRT engine. You can serialize this engine and save it to the path model_trt.engine on disk by calling

with open('model_trt.engine', 'wb') as f:
    f.write(model_trt.engine.serialize())
This serialized engine may be used with the TensorRT C++ runtime as described in the TensorRT Developer Guide.
If you're unfamiliar with using TensorRT directly, please let me know and I'd be happy to help.
Best, John
@jaybdub Thanks for your answer, but I am unfamiliar with TensorRT. How do I load 'model_trt.engine' with the TensorRT C++ runtime?
@jaybdub All right, I figured it out, as follows:
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include "NvInfer.h"

// gLogger: an instance of a class implementing nvinfer1::ILogger
// read the serialized engine from disk (note: open in binary mode)
std::ifstream cache("model_trt.engine", std::ios::binary);
std::stringstream gieModelStream;
gieModelStream << cache.rdbuf();
cache.close();

// determine the size of the serialized engine
gieModelStream.seekg(0, std::ios::end);
const int modelSize = gieModelStream.tellg();
gieModelStream.seekg(0, std::ios::beg);

// copy it into a contiguous buffer and deserialize with the TensorRT runtime
void* modelMem = malloc(modelSize);
gieModelStream.read((char*)modelMem, modelSize);
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(gLogger);
assert(runtime != nullptr);
nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(modelMem, modelSize, NULL);
assert(engine != nullptr);
free(modelMem);
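For completeness, here is a minimal sketch of running inference with the deserialized engine. It assumes a single-input, single-output model with the input at binding index 0 and the output at binding index 1; hostInput, hostOutput, inputSizeBytes, and outputSizeBytes are placeholders you would fill in for your model (requires <cuda_runtime.h>).

// minimal inference sketch (binding indices, sizes, and host pointers are assumptions)
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
assert(context != nullptr);
void* bindings[2];
cudaMalloc(&bindings[0], inputSizeBytes);   // device buffer for the input binding
cudaMalloc(&bindings[1], outputSizeBytes);  // device buffer for the output binding
cudaMemcpy(bindings[0], hostInput, inputSizeBytes, cudaMemcpyHostToDevice);
context->execute(1, bindings);              // batch size 1 (implicit-batch API)
cudaMemcpy(hostOutput, bindings[1], outputSizeBytes, cudaMemcpyDeviceToHost);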
@jaybdub Hi. I got an error while using the TensorRT C++ runtime to deserialize an engine that contains some interpolate layers.
[TensorRT] ERROR: Cannot deserialize plugin interpolate
Segmentation fault
How to fix that?
I have the same error when deserializing the engine in C++:
ERROR: Cannot deserialize plugin interpolate.
I have the same error when deserializing the engine in C++. Did you solve the problem? @TheLittleBee @donghoonsagong
Same problem; maybe it needs a pluginFactory. Have you solved the problem? @donghoonsagong @Raneee @TheLittleBee @Wallace00V @jaybdub
Same problem, I believe: [TensorRT] ERROR: getPluginCreator could not find plugin interpolate version 1 namespace torch2trt
@skyler1253 @xieydd @Raneee @donghoonsagong @TheLittleBee @Wallace00V @jaybdub Same problem. I can deserialize it in Python, but it fails in C++. Do you have any ideas?
@skyler1253 @xieydd @Raneee @donghoonsagong @TheLittleBee @Wallace00V @jaybdub @crazybob23 Same problem. Have you solved it?
@skyler1253 @xieydd @Raneee @donghoonsagong @TheLittleBee @Wallace00V @jaybdub @crazybob23 @zbz-cool
Same problem. Have you solved it?
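For anyone hitting the plugin errors above: TensorRT can only deserialize a plugin layer if a matching plugin creator is already registered in the C++ process when deserializeCudaEngine is called. Below is a minimal sketch, assuming the torch2trt plugins were built into a shared library that registers its creators at load time (via REGISTER_TENSORRT_PLUGIN or an equivalent registry call); the library name libtorch2trt_plugins.so is an assumption, so substitute whatever your torch2trt build actually produces.

#include <dlfcn.h>

// load the plugin library before deserializing; its static plugin-creator
// registrations run when the library is loaded (the name below is an assumption)
void* plugins = dlopen("libtorch2trt_plugins.so", RTLD_NOW | RTLD_GLOBAL);
assert(plugins != nullptr);

// only now deserialize the engine, exactly as in the snippet earlier in the thread
nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(modelMem, modelSize, NULL);
assert(engine != nullptr);

Alternatively, linking the plugin library into your application at build time has the same effect, since the static registrations run when the process loads it.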