
After converting the SAM2 model to ONNX, the inference results are significantly worse than those of the original model.

Open · lhz123 opened this issue 1 year ago · 7 comments

lhz123 avatar Aug 09 '24 09:08 lhz123

@lhz123 can you provide any comparison examples?

cile98 avatar Aug 09 '24 10:08 cile98

Hi! :D The inference flow is not identical to the one in the official SAM2 repo, including the input image dimensions, so the results are not directly comparable.

vietanhdev avatar Aug 10 '24 04:08 vietanhdev

@vietanhdev I've noticed all of the converted SAM2 models output a mask at 256x256 resolution. Is this configurable? Ideally I want it to match the input resolution (1024x1024).

The reason 256 isn't good enough is that after upscaling to 1024, the edges are very rough and don't overlay perfectly on the source image. I've applied some basic post-processing (sketched below), but the result isn't very accurate, especially for small objects/surfaces.

Does the original SAM model also output masks at 256 resolution? What are the limitations that make the ONNX version different from the PyTorch one?
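For reference, the basic post-processing I mentioned is roughly the following (a minimal sketch; the function name, `low_res_mask`, and the 0-logit threshold are my illustration, not samexporter's actual code):

```python
import cv2
import numpy as np

def upscale_mask_opencv(low_res_mask: np.ndarray, image_size=(1024, 1024)) -> np.ndarray:
    """Resize a (256, 256) float logit mask to the model input resolution.

    Bilinear resizing of such a coarse mask is what produces the rough,
    stair-stepped edges described above, especially on small objects.
    """
    resized = cv2.resize(low_res_mask, image_size, interpolation=cv2.INTER_LINEAR)
    return (resized > 0).astype(np.uint8)  # binarize at the usual 0-logit threshold
```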

marwand avatar Aug 16 '24 17:08 marwand

> @vietanhdev I've noticed all of the converted SAM2 models output a mask at 256x256 resolution. Is this configurable? Ideally I want it to match the input resolution (1024x1024).
>
> The reason 256 isn't good enough is that after upscaling to 1024, the edges are very rough and don't overlay perfectly on the source image. I've applied some basic post-processing, but the result isn't very accurate, especially for small objects/surfaces.
>
> Does the original SAM model also output masks at 256 resolution? What are the limitations that make the ONNX version different from the PyTorch one?

Pretty sure SAM1 also originally outputs them at 256x256 resolution and then upscales them.
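For reference, SAM1's `Sam.postprocess_masks` does the upscaling in two bilinear steps (this is a paraphrased sketch of that method from the segment-anything repo, not samexporter code):

```python
import torch.nn.functional as F

def postprocess_masks(masks, input_size, original_size, img_size=1024):
    """Paraphrase of SAM1's Sam.postprocess_masks: upscale the (B, C, 256, 256)
    low-res logits to the padded encoder resolution, strip the padding,
    then resize to the original image size."""
    masks = F.interpolate(masks, (img_size, img_size), mode="bilinear", align_corners=False)
    masks = masks[..., : input_size[0], : input_size[1]]  # remove square padding
    masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
    return masks
```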

cile98 avatar Aug 16 '24 17:08 cile98

@vietanhdev I recommend adding `masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)` to the decoder to get smoother results than doing the upscale with OpenCV.
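Concretely, the idea is to do the final upscale inside the exported graph rather than afterwards. A minimal sketch of such a wrapper (the class, its assumed `(masks, scores)` decoder output, and the `img_size` default are illustrative, not the notebook's exact code):

```python
import torch.nn as nn
import torch.nn.functional as F

class DecoderWithUpsample(nn.Module):
    """Hypothetical wrapper: run the SAM2 mask decoder, then upsample the
    256x256 low-res logits to the model input size inside the ONNX graph."""

    def __init__(self, decoder: nn.Module, img_size=(1024, 1024)):
        super().__init__()
        self.decoder = decoder
        self.img_size = img_size

    def forward(self, *decoder_inputs):
        masks, scores = self.decoder(*decoder_inputs)  # assumed output signature
        # The recommended line: bilinear upsampling baked into the export,
        # smoother than resizing the 256x256 output afterwards with OpenCV.
        masks = F.interpolate(masks, (self.img_size[0], self.img_size[1]),
                              mode="bilinear", align_corners=False)
        return masks, scores
```

Exporting this wrapper with `torch.onnx.export` turns the `F.interpolate` call into an ONNX `Resize` op, so consumers of the model get full-resolution masks directly.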

Here is the updated Colab notebook for export: https://colab.research.google.com/drive/1tqdYbjmFq4PK3Di7sLONd0RkKS0hBgId?usp=sharing

ibaiGorordo avatar Aug 17 '24 05:08 ibaiGorordo

Hi @ibaiGorordo, thank you for your great code! Could you help with a PR to this repo?

vietanhdev avatar Aug 17 '24 06:08 vietanhdev

> @vietanhdev I recommend adding `masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)` to the decoder to get smoother results than doing the upscale with OpenCV.
>
> Here is the updated Colab notebook for export: https://colab.research.google.com/drive/1tqdYbjmFq4PK3Di7sLONd0RkKS0hBgId?usp=sharing

This is great, thank you!

marwand avatar Aug 17 '24 12:08 marwand