Inference Forward
Dear Authors, congratulations on your excellent work. I wanted to address some open issues that have been raised regarding evaluation with the pretrained weights on the REFUGE dataset. It seems that reproducing the results reported in the paper requires training for additional epochs, which can be done by simply resuming training with the provided configuration and pipeline.
Additionally, I am particularly interested in deploying and optimizing the model for edge devices. While attempting to export the model, I noticed that the layer "/image_encoder/trunk/patch_embed/proj/Conv" is set to bfloat16, which is unsupported by ONNX Runtime. This could pose challenges when running the model on platforms such as Jetson devices. To work around the full pipeline, I offloaded most of the components of the forward pass and only ran forward_image, then extracted the features with _prepare_backbone_features and reshaped them (a rough sketch is shown below).
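Roughly, the encoder-only wrapper looks like the following. This is a minimal sketch: the return signature of _prepare_backbone_features (flattened (HW, B, C) feature maps plus their spatial sizes) is assumed from the upstream SAM2 code base and may differ in MedSAM2.

```python
import torch


class ImageEncoderWrapper(torch.nn.Module):
    """Traces only the image-encoder path of a SAM2-style model for ONNX export."""

    def __init__(self, sam2_model):
        super().__init__()
        self.model = sam2_model

    def forward(self, image):
        # image: (B, 3, H, W) float tensor, already resized and normalized
        backbone_out = self.model.forward_image(image)
        # Assumed return order (as in upstream SAM2):
        # backbone_out, vision_feats, vision_pos_embeds, feat_sizes
        _, vision_feats, _, feat_sizes = self.model._prepare_backbone_features(backbone_out)
        # Take the last feature level, which is flattened to (HW, B, C),
        # and reshape it back to a dense (B, C, H', W') feature map.
        h, w = feat_sizes[-1]
        feats = vision_feats[-1].permute(1, 2, 0).reshape(image.shape[0], -1, h, w)
        return feats
```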
bfloat16 is indeed efficient for training, but it is challenging for inference and is poorly supported by most edge-device compilers (a rough float32 export sketch is included after the warning below).
/home/zordhahr/miniconda3/envs/medsam2/lib/python3.12/site-packages/torch/onnx/utils.py:1738: UserWarning: The exported ONNX model failed ONNX shape inference. The model will not be executable by the ONNX Runtime. If this is unintended and you believe there is a bug, please report an issue at https://github.com/pytorch/pytorch/issues. Error reported by strict ONNX shape inference: [ShapeInferenceError] (op_type:Conv, node name: /image_encoder/trunk/patch_embed/proj/Conv): X typestr: T, has unsupported type: tensor(bfloat16) (Triggered internally at ../torch/csrc/jit/serialization/export.cpp:1469.)
_C._check_onnx_proto(proto)
INFO:root:Exported ONNX model to REFUGE_MedSAM2.onnx
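One way to avoid the bfloat16 issue would be to cast the model to float32 before export. A rough sketch follows; the input resolution, opset version, and output names are assumptions, and ImageEncoderWrapper refers to the encoder-only wrapper sketched above.

```python
import torch

# Casting to float32 removes the bfloat16 initializers that ONNX Runtime
# and Jetson-side compilers reject.
encoder = ImageEncoderWrapper(sam2_model).eval().to(torch.float32)

# Dummy input; 1024x1024 is an assumption -- use the model's configured image size.
dummy = torch.zeros(1, 3, 1024, 1024, dtype=torch.float32)

with torch.no_grad():
    torch.onnx.export(
        encoder,
        dummy,
        "REFUGE_MedSAM2.onnx",
        input_names=["image"],
        output_names=["image_embeddings"],
        dynamic_axes={"image": {0: "batch"}, "image_embeddings": {0: "batch"}},
        opset_version=17,
    )
```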
Hello, I'm curious how you exported to ONNX. Could you outline the rough steps for doing so?