
Export FastSAM to ONNX using a wrapper class

Open · mattia-z opened this issue 1 year ago · 0 comments

I'm trying to write a wrapper class that subclasses torch.nn.Module. I need to apply some custom rework to the model's output in the forward step and then export the whole thing to ONNX. Here is the wrapper class:

import torch
import torch.nn as nn
from fastsam import FastSAM, FastSAMPrompt
import cv2
import matplotlib.pyplot as plt
import onnxruntime as ort
import numpy as np

class FastSAMWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FastSAM("FastSAM-s.pt")
        device = torch.device('cpu')
        self.model.to(device)

    def forward(self, img_tt):
        # img_tt is the input image tensor with shape (1, 3, H, W).
        # FastSAM expects an HWC image, so convert it to a NumPy array first.
        with torch.no_grad():
            image_np = img_tt.squeeze(0).permute(1, 2, 0).numpy()
            # Run FastSAM in "everything" mode
            everything_results = self.model(image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
            prompt_process = FastSAMPrompt(image_np, everything_results, device='cpu')
            ann = prompt_process.everything_prompt()
            # HERE I'm going to do some post-processing on ann
            if not isinstance(ann, torch.Tensor):
                ann = torch.tensor(ann)
            return ann

If I run the model as follows:

my_model = FastSAMWrapper()

image_np = cv2.imread('../images/IMG_3619.jpeg')
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
image_np = cv2.resize(image_np, (1024, 1024))
image_tensor = torch.from_numpy(image_np).float()  # Ensure the tensor is of floating point type
reshaped_tensor = image_tensor.permute(2, 0, 1)  # Change to (3, 1024, 1024)
reshaped_tensor = reshaped_tensor.unsqueeze(0)  # Add batch dimension

results = my_model(reshaped_tensor)

everything works fine and the output is correct.

But when I export the model with the following code:

dummy_input = torch.randn((1, 3, 1024, 1024), device='cpu')

torch.onnx.export(
    my_model,                     # The PyTorch model
    dummy_input,               # The input tensor
    "fastsam_model.onnx",      # The output ONNX file name
    export_params=True,        # Store the trained parameter weights
    opset_version=11,          # ONNX opset version (11 is widely compatible, but you can try 12 or higher if needed)
    do_constant_folding=True,  # Whether to apply constant folding for optimization
    input_names=['input'], 
    output_names=['masks_data'], 
    dynamic_axes={'input': {0: 'image_tensor'}, 'masks_data': {0: 'masks'}},  # keys must match input_names / output_names
    verbose=True
)
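In torch.onnx.export, the keys of dynamic_axes are meant to be names declared in input_names or output_names. A mismatched key, for example 'output' when the only output is declared as 'masks_data', is not applied to any graph tensor. As far as is known, the TorchScript-based exporter only warns about such a key rather than failing. The helper below, invalid_dynamic_axes_keys, is hypothetical and not part of PyTorch. It is a minimal pre-export sanity check that flags such keys.

```python
def invalid_dynamic_axes_keys(input_names, output_names, dynamic_axes):
    """Return the dynamic_axes keys that match no declared input or output name.

    Hypothetical helper for illustration; it only compares names and does not
    call into torch.onnx in any way.
    """
    declared = set(input_names) | set(output_names)
    return sorted(key for key in dynamic_axes if key not in declared)


# Output declared as 'masks_data' but its dynamic axis keyed as 'output'
bad = invalid_dynamic_axes_keys(
    ["input"], ["masks_data"],
    {"input": {0: "image_tensor"}, "output": {0: "masks"}},
)
print(bad)   # ['output']

# Keys that match the declared names
good = invalid_dynamic_axes_keys(
    ["input"], ["masks_data"],
    {"input": {0: "image_tensor"}, "masks_data": {0: "masks"}},
)
print(good)  # []
```

Running the check before calling torch.onnx.export makes a silently ignored dynamic axis visible early. Even so, it does not explain the empty input list on its own.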

the list of model inputs returned by the following code is empty:

onnx_model = ort.InferenceSession("fastsam_model.onnx")
print("Model Inputs:", onnx_model.get_inputs())

I've also tried exporting with reshaped_tensor in place of dummy_input, but nothing changes.

Here is what I see when I load the model into Netron:

[Screenshot 2024-10-30 at 18.18.48]
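One plausible explanation, which is an assumption and not something confirmed in this thread: torch.onnx.export builds its graph by tracing the tensor operations executed in forward. The wrapper converts img_tt to a NumPy array before FastSAM runs, so nothing computed afterwards is recorded as depending on the input. The returned ann would then be baked in as a constant, and the input would be left with no consumers.

The pure-Python toy below mimics that mechanism. Traced, Graph, trace and the pruning step are all hypothetical names, and this is not PyTorch's implementation. It is only a sketch showing how leaving the traced world produces a graph with no inputs.

```python
from dataclasses import dataclass, field


@dataclass
class Graph:
    inputs: list = field(default_factory=list)
    nodes: list = field(default_factory=list)    # (op, args, output_name)
    outputs: list = field(default_factory=list)


class Traced:
    """Toy traced value: a concrete value plus the graph name it was recorded under."""

    def __init__(self, graph, value, name):
        self.graph, self.value, self.name = graph, value, name

    def __add__(self, other):
        # Record the addition as a graph node
        name = f"v{len(self.graph.nodes)}"
        rhs_name = other.name if isinstance(other, Traced) else repr(other)
        rhs_value = other.value if isinstance(other, Traced) else other
        self.graph.nodes.append(("add", (self.name, rhs_name), name))
        return Traced(self.graph, self.value + rhs_value, name)

    def to_python(self):
        # Analogous to tensor.numpy(): returns a plain value, so anything
        # computed from it is invisible to the tracer.
        return self.value


def trace(fn, example):
    g = Graph(inputs=["input"])
    out = fn(Traced(g, example, "input"))
    if isinstance(out, Traced):
        g.outputs.append(out.name)
    else:
        # Untraced result: stored in the graph as a constant
        g.nodes.append(("constant", (repr(out),), "const_out"))
        g.outputs.append("const_out")
    # Drop inputs that no recorded node consumes
    used = {arg for _, args, _ in g.nodes for arg in args}
    g.inputs = [i for i in g.inputs if i in used]
    return g


def stays_traced(x):
    return x + 1             # recorded as an 'add' node that reads the input


def leaves_trace(x):
    plain = x.to_python()    # like .numpy(): drops out of the trace
    return plain + 1         # plain arithmetic, never recorded


g1 = trace(stays_traced, 41)
g2 = trace(leaves_trace, 41)
print(g1.inputs, g1.outputs)  # ['input'] ['v0']
print(g2.inputs, g2.nodes)    # [] [('constant', ('42',), 'const_out')]
```

If this is indeed what happens, an exportable forward would need to be built only from torch operations on img_tt, with the NumPy-based FastSAMPrompt steps kept outside the exported module.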

I need to do this because the model will be used in a Unity app, and we want to reduce the number of operations performed in the app (the post-processing on ann, i.e. on the model output).

mattia-z · Oct 30 '24 17:10