IGEV icon indicating copy to clipboard operation
IGEV copied to clipboard

onnx模型

Open 5p6 opened this issue 5 months ago • 3 comments

请问如何将IGEV模型的结构和权重转换成ONNX模型？

5p6 avatar Feb 29 '24 02:02 5p6

如果可以，能否提供一下将模型转换为ONNX的代码？

5p6 avatar Feb 29 '24 03:02 5p6

+1

zhengshunkai avatar Mar 12 '24 02:03 zhengshunkai

Below is an example of the code I use to convert the model to ONNX on the CPU:

Make sure the input tensor's height and width are multiples of 64. Let me know if you need a Docker image for the conversion environment.

from core.igev_stereo import IGEVStereo
import torch
import argparse

from torchsummary import summary
import torchvision
# from tensorflow.python.compiler.tensorrt import trt_convert as trt

import argparse

class Args:
    """Hard-coded replacement for command-line options.

    An instance is passed straight to ``IGEVStereo(args)``; every option is
    exposed as a plain instance attribute. Edit the values below instead of
    passing flags.
    """

    def __init__(self):
        # Built fresh on every call, so mutable values (hidden_dims) are never
        # shared between instances.
        settings = dict(
            # Path to the pretrained checkpoint file (.pth).
            restore_ckpt='./pretrained_models/sceneflow/sceneflow.pth',
            save_numpy=False,
            # Glob patterns for the left/right demo images and the output dir.
            left_imgs="./demo-imgs/*/im0.png",
            right_imgs="./demo-imgs/*/im1.png",
            output_directory="./demo-output/",
            mixed_precision=False,
            valid_iters=32,
            hidden_dims=[128] * 3,
            corr_implementation="reg",
            shared_backbone=False,
            corr_levels=2,
            corr_radius=4,
            n_downsample=2,
            slow_fast_gru=False,
            n_gru_layers=3,
            max_disp=192,
        )
        vars(self).update(settings)

# Export IGEVStereo (with its pretrained weights) to an ONNX file on the CPU.

# Create an instance of the args class
args = Args()

# Spatial size of the dummy stereo pair used for tracing. Per the advice in
# this thread, both dimensions must be multiples of 64.
HEIGHT, WIDTH = 256, 256
if HEIGHT % 64 or WIDTH % 64:
    raise ValueError(
        f"input height/width must be multiples of 64, got {HEIGHT}x{WIDTH}"
    )


class _InferenceWrapper(torch.nn.Module):
    """Expose IGEVStereo inference as ``forward(left, right)`` for tracing.

    The iteration count and the test-mode flag are fixed here, so the exported
    graph only takes the two image tensors as inputs.
    """

    def __init__(self, net, iters):
        super().__init__()
        self.net = net
        self.iters = iters

    def forward(self, left, right):
        # NOTE(review): `iters`/`test_mode` follow the upstream IGEV demo
        # script's call; test_mode=True is expected to return only the final
        # disparity — confirm against core/igev_stereo.py.
        return self.net(left, right, iters=self.iters, test_mode=True)


# Load the weights through a DataParallel wrapper, as the upstream demo does.
# Presumably the checkpoint keys carry a "module." prefix — verify if loading
# fails. map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only box.
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
model.load_state_dict(torch.load(args.restore_ckpt, map_location="cpu"))
net = model.module.to("cpu").eval()

num_params = sum(p.numel() for p in net.parameters())
print(f"IGEVStereo parameters: {num_params:,}")

export_model = _InferenceWrapper(net, args.valid_iters).eval()

# Dummy stereo pair with float32 pixel values in [0, 255).
left = torch.rand(1, 3, HEIGHT, WIDTH) * 255
right = torch.rand(1, 3, HEIGHT, WIDTH) * 255

with torch.no_grad():
    torch.onnx.export(
        export_model,                   # PyTorch model (inference wrapper)
        (left, right),                  # Example stereo pair
        "model2.onnx",                  # Output ONNX file path
        input_names=["left", "right"],  # One name per model input
        output_names=["output"],        # Output name used in the ONNX model
        export_params=True,             # Embed the loaded weights
        opset_version=16,
        verbose=True,
    )

HubertBlach avatar Apr 10 '24 12:04 HubertBlach