MoGe

Export to ONNX

Open sctrueew opened this issue 1 year ago • 34 comments

Hi, is it possible to run it with ONNX?

I really appreciate any help you can provide.

sctrueew, Nov 13 '24

Hi. I tried a few approaches with PyTorch's built-in ONNX export tools, but unfortunately they failed. I am not familiar with ONNX and don't know how to fix it. Sorry I can't help :( Maybe you could try some other tools. Please let me know if you find a solution.

EasternJournalist, Nov 28 '24

Hi @sctrueew and @EasternJournalist, I have successfully exported the model to ONNX and validated that the outputs of the original and converted models are identical. If you want, I can submit a PR with the script. However, it also required some minor changes (mainly function naming) in the moge_model.py file, so we'll need to think about how to work around that neatly.

iariav, Jan 06 '25
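Floating-point outputs from different runtimes rarely match bit for bit, so "identical" usually means "equal within a tolerance". A minimal NumPy sketch of such a check (the helper name compare_outputs and the default tolerances are illustrative assumptions, not part of MoGe): it compares each named output, and also requires inf/nan values, such as masked-out depth, to sit at the same positions.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-4):
    """Compare two dicts of named arrays, e.g. PyTorch vs. ONNX Runtime outputs.

    Returns {name: (max_abs_err, ok)}. Non-finite values (inf/nan) must sit at
    the same positions in both arrays; finite values must satisfy
    |ref - cand| <= atol + rtol * |cand| (the np.allclose rule).
    """
    report = {}
    for name, ref in reference.items():
        ref = np.asarray(ref, dtype=np.float64)
        cand = np.asarray(candidate[name], dtype=np.float64)
        if ref.shape != cand.shape:
            raise ValueError(f"{name}: shape {ref.shape} != {cand.shape}")
        ref_finite, cand_finite = np.isfinite(ref), np.isfinite(cand)
        same_layout = np.array_equal(ref_finite, cand_finite)
        both = ref_finite & cand_finite
        diff = np.abs(ref[both] - cand[both])
        max_abs = float(diff.max()) if diff.size else 0.0
        ok = same_layout and bool(np.allclose(ref[both], cand[both], rtol=rtol, atol=atol))
        report[name] = (max_abs, ok)
    return report
```

Feed the same input tensor to the PyTorch model and to the ONNX session, then pass both sets of outputs (for example keyed "points" and "mask") to the helper. The same check applies to a TensorRT engine's outputs.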

@iariav I failed to export the model to ONNX. Would you mind sharing how you did it?

Minnchong, Feb 12 '25

Hello, I am also trying something similar. I have successfully exported the ONNX model and it runs inference, but I found that ONNX inference does not support a dynamic batch size or dynamic resolution. Even though I added dynamic axes when exporting, some operators still raise broadcast errors when the input size changes. Have you encountered this problem, and how did you solve it? @iariav

Bin-ze, Feb 21 '25
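A sketch of a dynamic_axes mapping for torch.onnx.export that marks height and width as dynamic in addition to the batch axis. The axis labels are arbitrary, and the output layouts, points (N, H, W, 3) and mask (N, H, W), are inferred from how infer() indexes them later in this thread. Declaring an axis dynamic does not by itself make the graph shape-agnostic: sizes computed in Python during tracing may still be recorded as constants, which could explain broadcast errors when the input size changes.

```python
def make_dynamic_axes(dynamic_batch=True, dynamic_resolution=True):
    """Build a torch.onnx.export dynamic_axes mapping for the wrapped MoGe forward.

    Assumed layouts: images (N, 3, H, W), points (N, H, W, 3), mask (N, H, W).
    """
    image_axes, map_axes = {}, {}
    if dynamic_batch:
        image_axes[0] = "N"
        map_axes[0] = "N"
    if dynamic_resolution:
        # Height/width sit at dims 2, 3 in NCHW input but dims 1, 2 in the NHWC outputs.
        image_axes.update({2: "H", 3: "W"})
        map_axes.update({1: "H", 2: "W"})
    return {
        "images": image_axes,
        "points": dict(map_axes),
        "mask": dict(map_axes),
    }


# Usage (hypothetical): torch.onnx.export(..., dynamic_axes=make_dynamic_axes())
```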

Hi @Bin-ze, actually I haven't tried exporting with a dynamic input shape. My data all has a fixed size, so I just used a fixed-size input in the exported ONNX model.

iariav, Feb 23 '25

I have successfully implemented ONNX export and got the same output as PyTorch.

I am trying to use TensorRT inference, but there are still some problems with operators.

Have you tried TensorRT inference?

Bin-ze, Feb 23 '25

I also get the same output from the PyTorch and ONNX models, but when I export to TensorRT, even at FP32, I get corrupted results. Conversion passes without warnings or errors, but the results are really bad. I will try to investigate this further when I have the time.

iariav, Feb 24 '25

After today's efforts, I have implemented TensorRT inference and got good results. Inference time is 0.025 s per image on a 4090 GPU, which is still too slow for me.

I am still working on dynamic batch support. If you make any progress, could you let me know?

Bin-ze, Feb 24 '25
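The latency figures in this thread (0.025 s with TensorRT, about 250 ms, 100 ms with PyTorch) are hard to compare without a shared timing method. A minimal, framework-agnostic sketch (the helper name benchmark is hypothetical) that discards warm-up runs and reports the median, which is less sensitive to outliers than a single measurement:

```python
import statistics
import time


def benchmark(run_once, warmup=10, iters=50):
    """Time a zero-argument callable; return (median_ms, mean_ms).

    Warm-up calls are not timed, so one-off costs such as CUDA context setup,
    lazy allocation, or autotuning do not inflate the numbers. For GPU work,
    run_once must block until results are ready (e.g. call
    torch.cuda.synchronize() inside it); otherwise only the launch is timed.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):
        run_once()
    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        run_once()
        samples_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples_ms), statistics.fmean(samples_ms)


# Example with ONNX Runtime (sess and image as in the inference scripts in this thread):
# median_ms, mean_ms = benchmark(lambda: sess.run(None, {"images": image}))
```

Reporting whether pre- and post-processing are inside run_once also matters, since the thread's numbers may cover different parts of the pipeline.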

@iariav @Bin-ze Hi, thanks for your work. Would it be possible to share the code?

sctrueew, Feb 24 '25

+1, TensorRT code would be amazing! Especially in C++. 😁

antithing, Feb 24 '25

@Bin-ze Thanks for the update. I also got about 250 ms per image, but as I said, my results currently don't look good, and I'm still looking into it.

Regarding dynamic shapes: are you aiming for dynamic resolution, dynamic batch size, or both?

iariav, Feb 25 '25

Both dynamic resolution and dynamic batch size.

Bin-ze, Feb 25 '25

I'd like to know which GPU you're using to get 250 ms latency. I can get 100 ms latency on a 4090 even with PyTorch. @iariav

Bin-ze, Feb 25 '25


@Bin-ze Hi, is the TensorRT output the same as the ONNX output? Would it be possible for you to share the code? Thank you in advance.

sctrueew, Feb 25 '25

Step 1: Disable anti-aliasing in the forward function.

[Image: screenshot of the change to the forward function]
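The change in step 1 survives only as a screenshot. As a hedged alternative to editing moge_model.py by hand (an assumption, not what was done here), one could patch torch.nn.functional.interpolate for the duration of the export so every antialias=True call runs with antialias=False. This only reaches code that calls the function through the module attribute, as the MoGe code quoted in this thread does with F.interpolate(..., antialias=True). The premise that the exporter cannot handle the anti-aliased resize ops is also an assumption, and disabling anti-aliasing changes the numerics slightly.

```python
import contextlib

import torch
import torch.nn.functional as F


@contextlib.contextmanager
def antialias_disabled():
    """Temporarily force antialias=False in torch.nn.functional.interpolate.

    Callers that look up F.interpolate at call time see the patched version;
    code that did `from torch.nn.functional import interpolate` does not.
    Assumes antialias is passed as a keyword argument, not positionally.
    """
    original = F.interpolate

    def patched(*args, **kwargs):
        kwargs["antialias"] = False
        return original(*args, **kwargs)

    F.interpolate = patched
    try:
        yield
    finally:
        F.interpolate = original  # always restore, even if export fails


# Hypothetical usage with the export script from step 2:
# with antialias_disabled():
#     torch.onnx.export(model, data, "model.onnx", input_names=["images"],
#                       output_names=["points", "mask"], opset_version=17)
```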

Step 2: Export only the model's forward function from MoGe:



import torch
import torch.nn as nn
from moge.model import MoGeModel


# Load the model from the Hugging Face hub (or from a local checkpoint).
model_org = MoGeModel.from_pretrained("moge-vitl/model.pt").eval()


class Model(nn.Module):
    """Thin wrapper so that only MoGeModel.forward is exported (not infer())."""

    def __init__(self) -> None:
        super().__init__()
        self.model = model_org

    def forward(self, images):
        # MoGe's output dict; assumed to hold "points" and "mask", exported in that order.
        outputs = self.model(images)
        return outputs


model = Model().eval()
# Example input (batch, channels, height, width); the exported graph is tied to this shape.
data = torch.rand(1, 3, 530, 942)

# Dynamic-batch axes (not used below: this export uses fixed shapes).
dynamic_axes = {
    "images": {0: "N"},
    "points": {0: "N"},
    "mask": {0: "N"},
}

output_file = "model.onnx"

torch.onnx.export(
    model,
    data,
    output_file,
    input_names=["images"],
    output_names=["points", "mask"],
    dynamic_axes=None,  # pass dynamic_axes=dynamic_axes to try a dynamic batch
    opset_version=17,
    verbose=False,
    do_constant_folding=False,
)


import onnx
import onnxsim

# Simplify the graph (folding constant subgraphs) and check it with the same input shape.
dynamic = True
input_shapes = {"images": list(data.shape)} if dynamic else None
onnx_model_simplify, check = onnxsim.simplify(output_file, test_input_shapes=input_shapes)
onnx.save(onnx_model_simplify, output_file)
print(f"Simplify onnx model {check}...")





Step 3: Run inference with onnxruntime, wrapping the ONNX model with MoGe's pre- and post-processing, like this:

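def forward(input_image, sess):
    # preprocess/postprocess: the steps of MoGe's infer() before and after
    # self.forward(image); the author's versions are not posted in this thread.
    image, original_height, original_width = preprocess(input_image)
    # Run only the exported network; sess is an onnxruntime.InferenceSession.
    output = sess.run(
        output_names=["points", "mask"],
        input_feed={"images": image.numpy()},
    )
    return_dict = postprocess(output, original_height, original_width)
    return return_dict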
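Because the step 2 export uses a fixed input shape, pre-processing has to resize every image to exactly the exported size (530 x 942 in the script above). For reference, here is the size MoGe's own infer() would pick, following the area logic in the infer() code quoted later in this thread. It is a sketch with a hypothetical helper name, and trained_area_range = (500*500, 700*700) is the value used in the inference script in this thread, assumed to match the checkpoint.

```python
def expected_inference_size(height, width, resolution_level=9,
                            trained_area_range=(500 * 500, 700 * 700)):
    """Return the (height, width) that MoGe's infer() resizes an image to.

    The target area is interpolated between the trained min and max area by
    resolution_level (0-9), and the aspect ratio is preserved.
    """
    if not 0 <= resolution_level <= 9:
        raise ValueError("resolution_level must be in 0..9")
    min_area, max_area = trained_area_range
    expected_area = min_area + (max_area - min_area) * (resolution_level / 9)
    area = height * width
    if expected_area == area:
        return height, width  # infer() skips the resize in this case
    scale = (expected_area / area) ** 0.5
    return int(height * scale), int(width * scale)


print(expected_inference_size(530, 942))  # (525, 933)
```

For the 530 x 942 example input this gives 525 x 933 at resolution_level 9, so an image resized with infer()'s logic would not fit a model exported at 530 x 942 without an extra resize.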

Step 4: To build a TensorRT engine, use trtexec.

This is my implementation.

Bin-ze, Feb 26 '25

@Bin-ze Thank you for your work. Could you please explain the pre-processing and post-processing steps? In MoGe's infer() below, I assume this line should be replaced by a call to the ONNX model:

    output = self.forward(image)

becomes

    output = sess.run(
        output_names=["points", "mask"],
        input_feed={"images": image.numpy()},
    )

    def infer(
        self, 
        image: torch.Tensor, 
        force_projection: bool = True,
        resolution_level: int = 9,
        apply_mask: bool = True,
        fov_x: Union[Number, torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        User-friendly inference function

        ### Parameters
        - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
        - `resolution_level`: the resolution level to use for the output point map in 0-9. Default: 9 (highest)
        - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
        - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
            
        ### Returns

        A dictionary containing the following keys:
        - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
        - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
        - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False

        original_height, original_width = image.shape[-2:]
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        min_area, max_area = self.trained_area_range
        expected_area = min_area + (max_area - min_area) * (resolution_level / 9)
        
        expected_height, expected_width = original_height, original_width
        if expected_area != area:
            expected_width, expected_height = int(original_width * (expected_area / area) ** 0.5), int(original_height * (expected_area / area) ** 0.5)
            image = F.interpolate(image, (expected_height, expected_width), mode="bicubic", align_corners=False, antialias=True)
        
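        # Was: output = self.forward(image). Replaced by the ONNX model (sess is an
        # onnxruntime.InferenceSession created elsewhere). sess.run returns a list in
        # output_names order, not a dict, so unpack it and convert back to torch tensors.
        points_np, mask_np = sess.run(
            output_names=["points", "mask"],
            input_feed={"images": image.numpy()},
        )
        points, mask = torch.from_numpy(points_np), torch.from_numpy(mask_np)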

        # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
        if fov_x is None:
            focal, shift = recover_focal_shift(points, None if mask is None else mask > 0.5)
        else:
            focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
            if focal.ndim == 0:
                focal = focal[None].expand(points.shape[0])
            _, shift = recover_focal_shift(points, None if mask is None else mask > 0.5, focal=focal)
        fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
        fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 
        intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
        depth = points[..., 2] + shift[..., None, None]
        
        # If projection constraint is forced, recompute the point map using the actual depth map
        if force_projection:
            points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=expected_width, height=expected_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :])
        else:
            points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]

        # Resize the output to the original resolution
        if expected_area != area:
            points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1)
            depth = F.interpolate(depth.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)
            mask = None if mask is None else F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)
        
        # Apply mask if needed
        if self.output_mask and apply_mask:
            mask_binary = (depth > 0) & (mask > 0.5)
            points = torch.where(mask_binary[..., None], points, torch.inf)
            depth = torch.where(mask_binary, depth, torch.inf)

        if omit_batch_dim:
            points = points.squeeze(0)
            intrinsics = intrinsics.squeeze(0)
            depth = depth.squeeze(0)
            if self.output_mask:
                mask = mask.squeeze(0)

        return_dict = {
            'points': points,
            'intrinsics': intrinsics,
            'depth': depth,
        }
        if self.output_mask:
            return_dict['mask'] = mask > 0.5

        return return_dict

sctrueew, Feb 26 '25

@Bin-ze Hi, I couldn't get the correct output. Could you please take a look?

import cv2
import torch
import torch.nn.functional as F
import onnxruntime
import numpy as np
import utils3d
from typing import *
from numbers import Number
from moge.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift

output_file = r"model.onnx"
sess = onnxruntime.InferenceSession(output_file)
input_name = sess.get_inputs()[0].name
image_mean_np = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1).astype(np.float32)
image_std_np = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1).astype(np.float32)
trained_area_range = (500 * 500, 700 * 700)


def preprocess(input_image, resolution_level=9, target_size=(512, 512)):
    # input_image: RGB uint8 numpy array in HWC layout.
    original_height, original_width = input_image.shape[:2]
    image = input_image.astype(np.float32) / 255.0
    image = image.transpose(2, 0, 1)       # HWC -> CHW
    image = np.expand_dims(image, axis=0)  # CHW -> NCHW
    # NOTE: check whether the exported forward() already normalizes with the
    # ImageNet mean/std; if it does, this line normalizes the input twice.
    image = (image - image_mean_np) / image_std_np

    # Force-resize to target_size. NOTE: this ignores the aspect ratio and the
    # area-based resizing of MoGe's infer(), and a fixed-shape ONNX model only
    # accepts the exact (H, W) it was exported with.
    image_torch = torch.from_numpy(image)
    image = F.interpolate(image_torch, size=target_size, mode="bicubic", align_corners=False, antialias=True).numpy()

    # Area-based resizing from MoGe's infer(), disabled while target_size is enforced:
    # area = original_height * original_width
    # min_area, max_area = trained_area_range
    # expected_area = min_area + (max_area - min_area) * (resolution_level / 9)
    # if expected_area != area:
    #     expected_width, expected_height = int(original_width * (expected_area / area) ** 0.5), int(original_height * (expected_area / area) ** 0.5)
    #     image = F.interpolate(torch.from_numpy(image), size=(expected_height, expected_width), mode="bicubic", align_corners=False, antialias=True).numpy()

    return image, original_height, original_width


def postprocess(output, original_height, original_width, original_image_shape, resolution_level=9, force_projection=True, apply_mask=True, fov_x=None):
    points_onnx, mask_onnx = output
    points = torch.from_numpy(points_onnx)
    mask = torch.from_numpy(mask_onnx) if mask_onnx is not None else None

    # No need to resize back in postprocess if we enforced 512x512 in preprocess
    # expected_height, expected_width = points.shape[1:3]
    # area = original_height * original_width
    # expected_area = expected_height * expected_width

    # if expected_area != area:
    #     points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1)
    #     if mask is not None:
    #         mask = F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)


    # NOTE: MoGe's infer() recovers focal/shift at the network resolution and only
    # afterwards resizes to the original size; here the point map is resized first.
    points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1)
    if mask is not None:
        mask = F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)


    aspect_ratio = original_width / original_height

    if fov_x is None:
        focal, shift = recover_focal_shift(points, None if mask is None else mask > 0.5)
    else:
        focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x) / 2))
        if focal.ndim == 0:
            focal = focal[None].expand(points.shape[0])
        _, shift = recover_focal_shift(points, None if mask is None else mask > 0.5, focal=focal)
    fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
    fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
    intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
    depth = points[..., 2] + shift[..., None, None]

    if force_projection:
        points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=original_width, height=original_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :])
    else:
        points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]

    return_dict = {
        'points': points.numpy(),
        'intrinsics': intrinsics.numpy(),
        'depth': depth.numpy(),
    }
    if mask is not None:
        return_dict['mask'] = (mask > 0.5).numpy()

    return return_dict


def forward(input_image, sess):
    image, original_height, original_width = preprocess(input_image)
    output_names = [output.name for output in sess.get_outputs()]
    output_feed = sess.run(
        output_names=output_names,
        input_feed={input_name: image},
    )

    output_dict = {}
    for i, output_name in enumerate(output_names):
        output_dict[output_name] = output_feed[i]

    return_dict = postprocess(output=[output_dict['points'], output_dict.get('mask')],
                             original_height=original_height,
                             original_width=original_width,
                             original_image_shape=input_image.shape[:2]) # Pass original image shape
    return return_dict


if __name__ == '__main__':
    image_path = r"04.jpg" # use your image path
    input_image_np = cv2.imread(image_path)
    input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB)

    # Run inference
    result_dict = forward(input_image_np, sess)

    points = result_dict['points']
    depth = result_dict['depth']
    intrinsics = result_dict['intrinsics']
    mask = result_dict.get('mask') # Mask is optional depending on your model export

    print("Points shape:", points.shape)
    print("Depth shape:", depth.shape)
    print("Intrinsics shape:", intrinsics.shape)
    if mask is not None:
        print("Mask shape:", mask.shape)

    # You can further process or visualize the results here, e.g., visualize depth map
    import matplotlib.pyplot as plt
    plt.imshow(depth[0]) # Assuming batch size is 1
    plt.colorbar()
    plt.title("Depth Map from ONNX Inference")
    plt.show()

    if mask is not None:
        plt.imshow(mask[0], cmap='gray') # Assuming batch size is 1
        plt.title("Mask from ONNX Inference")
        plt.show()

sctrueew, Feb 26 '25

My pre- and post-processing are copied from MoGe, with some optional operations removed, for example MoGe's dynamic resolution support.

Therefore, I did not upload the pre- and post-processing code.

MoGe's infer function can be abstracted into pre-processing -> forward -> post-processing, just as you implemented it.

When you say "I couldn't get the correct output", do you mean the result is incorrect or that you get an error?

Bin-ze, Feb 26 '25
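When porting the post-processing out of PyTorch, the intrinsics arithmetic from the infer() code quoted above can be reproduced in NumPy. With a = width / height, MoGe's focal (relative to half the image diagonal) gives fx = focal/2 * sqrt(1 + a^2) / a and fy = focal/2 * sqrt(1 + a^2), and a given fov_x maps to focal = a / sqrt(1 + a^2) / tan(fov_x / 2). The function names below are hypothetical, and the matrix layout assumes intrinsics_from_focal_center builds the usual [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] in normalized image coordinates.

```python
import numpy as np


def normalized_intrinsics(focal, aspect_ratio):
    """3x3 intrinsics in normalized image coordinates, principal point (0.5, 0.5).

    focal is relative to half the image diagonal; aspect_ratio = width / height.
    """
    diag = (1 + aspect_ratio ** 2) ** 0.5
    fx = focal / 2 * diag / aspect_ratio
    fy = focal / 2 * diag
    return np.array([[fx, 0.0, 0.5],
                     [0.0, fy, 0.5],
                     [0.0, 0.0, 1.0]])


def focal_from_fov_x(fov_x_deg, aspect_ratio):
    """Convert a horizontal FoV in degrees to the diagonal-relative focal."""
    diag = (1 + aspect_ratio ** 2) ** 0.5
    return aspect_ratio / diag / np.tan(np.deg2rad(fov_x_deg) / 2)


def to_pixel_intrinsics(K_norm, width, height):
    """Scale normalized intrinsics to pixel units for a width x height image."""
    return np.diag([width, height, 1.0]) @ K_norm
```

Substituting the FoV formula gives fx = 1 / (2 tan(fov_x / 2)) in normalized units, and after scaling to pixels fx and fy come out equal for any focal, i.e. this parameterization assumes square pixels.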

@Bin-ze Hi,

This is the PyTorch result:

[Image: PyTorch result]

and this is the ONNX result:

[Image: ONNX result]

sctrueew, Feb 26 '25

Would it be possible to share your pre- and post-processing functions, or the complete code for inference with ONNX?

sctrueew, Feb 26 '25

Hi @Bin-ze, I did it! I really appreciate your help. Thank you! 😊

sctrueew, Feb 26 '25

Congratulations!

Bin-ze, Feb 27 '25

If you make any progress on dynamic resolution, please let me know, quack!

Bin-ze, Feb 27 '25

Sure, I'm working on it. Did you test in Python or C++?

sctrueew, Feb 27 '25

@Bin-ze I got an error when converting to TensorRT: Assertion failed: (axes.allValuesKnown()) && "This version of TensorRT does not support dynamic axes.". Do you have any suggestions?

sctrueew, Feb 27 '25

Maybe try a newer version of TensorRT? I don't have this problem with 8.6.1.

Bin-ze, Feb 27 '25

I downloaded TensorRT 8.6.1.6, but I am still encountering this error: Assertion failed: (axes.allValuesKnown()) && "This version of TensorRT does not support dynamic axes."

sctrueew, Feb 28 '25

@Bin-ze Hi, could you please share the command you used to generate the TensorRT engine?

sctrueew, Mar 04 '25

trtexec --onnx="model.onnx" --shapes=data:1x3x530x942 --saveEngine="model.engine" --fp16

Bin-ze, Mar 04 '25
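The export script in step 2 names the network input images, while this command passes a shape for an input called data, so with that particular model the --shapes flag may not refer to any input. For a dynamic-batch model, trtexec instead takes an optimization profile through --minShapes, --optShapes and --maxShapes. A small sketch that assembles either form (the helper and its defaults are illustrative, taken from the shapes used in this thread):

```python
import shlex


def trtexec_command(onnx_path, engine_path, input_name="images",
                    chw=(3, 530, 942), min_batch=1, opt_batch=1, max_batch=1,
                    fp16=True):
    """Assemble a trtexec command line as a list of arguments.

    min_batch == opt_batch == max_batch emits one --shapes spec; otherwise an
    optimization profile is emitted with --minShapes/--optShapes/--maxShapes,
    which only makes sense if the ONNX batch axis was exported as dynamic.
    """
    if not min_batch <= opt_batch <= max_batch:
        raise ValueError("need min_batch <= opt_batch <= max_batch")

    def spec(batch):
        # trtexec shape syntax: <input name>:<dim>x<dim>x...
        return f"{input_name}:" + "x".join(str(d) for d in (batch, *chw))

    cmd = ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]
    if min_batch == max_batch:
        cmd.append(f"--shapes={spec(min_batch)}")
    else:
        cmd += [f"--minShapes={spec(min_batch)}",
                f"--optShapes={spec(opt_batch)}",
                f"--maxShapes={spec(max_batch)}"]
    if fp16:
        cmd.append("--fp16")
    return cmd


print(shlex.join(trtexec_command("model.onnx", "model.engine")))
# trtexec --onnx=model.onnx --saveEngine=model.engine --shapes=images:1x3x530x942 --fp16
```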

I am using TensorRT-8.6.1.6, but I am still encountering the error:

[03/04/2025-10:30:10] [E] [TRT] ModelImporter.cpp:772: --- Begin node ---
[03/04/2025-10:30:10] [E] [TRT] ModelImporter.cpp:773: input: "/model/Cast_1_output_0" input: "/model/Unsqueeze_6_output_0" input: "/model/Unsqueeze_7_output_0" input: "/model/Unsqueeze_5_output_0" input: "/model/Unsqueeze_8_output_0" output: "/model/Slice_1_output_0" name: "/model/Slice_1" op_type: "Slice"
[03/04/2025-10:30:10] [E] [TRT] ModelImporter.cpp:774: --- End node ---
[03/04/2025-10:30:10] [E] [TRT] ModelImporter.cpp:777: ERROR: builtin_op_importers.cpp:4493 In function importSlice: [8] Assertion failed: (axes.allValuesKnown()) && "This version of TensorRT does not support dynamic axes."
[03/04/2025-10:30:10] [E] Failed to parse onnx file
[03/04/2025-10:30:10] [I] Finished parsing network model. Parse time: 1.7587
[03/04/2025-10:30:10] [E] Parsing model failed
[03/04/2025-10:30:10] [E] Failed to create engine from model or file.
[03/04/2025-10:30:10] [E] Engine set up failed

sctrueew, Mar 04 '25