RobustSAM
About ONNX
I have tried to convert the model to ONNX, but ran into a lot of issues. After that I saw your SamOnnxModel(nn.Module), but I do not know how to use it.
As I understand it, it should be something like this:
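import os
import torch

# `model` (the loaded RobustSAM model), `opt` and SamOnnxModel come from the surrounding script
model_to_export = SamOnnxModel(model=model, return_single_mask=True)

# Shapes follow the original SAM export script (ViT backbone, 1024x1024 encoder input)
dummy_image_embeddings = torch.randn(1, 256, 64, 64, device='cuda')  # image-encoder output, not the raw image
dummy_point_coords = torch.randint(0, 1024, (1, 1, 2), device='cuda').float()  # (batch, num_points, xy)
dummy_point_labels = torch.randint(0, 4, (1, 1), device='cuda').float()  # (batch, num_points)
dummy_mask_input = torch.randn(1, 1, 256, 256, device='cuda')  # 4x the 64x64 embedding grid
dummy_has_mask_input = torch.tensor([1.0], device='cuda')  # shape (1,), as in SAM's export script
dummy_orig_im_size = torch.tensor([1200.0, 1200.0], device='cuda')  # (H, W) of the original image

inputs = ['image_embeddings', 'point_coords', 'point_labels',
          'mask_input', 'has_mask_input', 'orig_im_size']
outputs = ['upscaled_masks', 'scores', 'masks']
# splitext avoids rewriting other 'pth' substrings in the path
onnx_path = os.path.splitext(opt.checkpoint_path)[0] + '.onnx'
torch.onnx.export(model_to_export,
                  (dummy_image_embeddings, dummy_point_coords, dummy_point_labels,
                   dummy_mask_input, dummy_has_mask_input, dummy_orig_im_size),
                  onnx_path,
                  export_params=True, do_constant_folding=True,
                  input_names=inputs, output_names=outputs, opset_version=19,
                  # let the number of prompt points vary at inference time
                  dynamic_axes={'point_coords': {1: 'num_points'},
                                'point_labels': {1: 'num_points'}},
                  verbose=False)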
But I do not know how to use mask_input and has_mask_input. These inputs are not part of the base predict method of the torch model.
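For reference, the original Segment Anything ONNX example feeds these two inputs as follows, and SamOnnxModel here is assumed (not checked against RobustSAM's code) to follow the same convention. `mask_input` is a 1x1x256x256 low-resolution mask prompt, typically the low-res mask logits from a previous prediction. `has_mask_input` is a one-element float flag: with 0, the model replaces the mask embedding with its learned "no mask" embedding, so `mask_input` can simply be zeros. The sketch below builds the feed dict in numpy. `build_decoder_inputs` and `transform_coords` are hypothetical helper names, the 1024/256 sizes are SAM ViT defaults, and the coordinate transform mirrors SAM's `ResizeLongestSide.apply_coords`.

```python
import numpy as np

# SAM ViT defaults; assumed to carry over to RobustSAM
IMG_SIZE = 1024  # longest side of the resized encoder input
LOW_RES = 256    # mask_input side: 4x the 64x64 embedding grid


def transform_coords(coords, orig_hw, img_size=IMG_SIZE):
    """Map (x, y) pixel coords of the original image into the resized frame
    (longest side = img_size), mirroring SAM's ResizeLongestSide.apply_coords."""
    old_h, old_w = orig_hw
    scale = img_size / max(old_h, old_w)
    new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)
    out = np.asarray(coords, dtype=np.float32).copy()
    out[..., 0] *= new_w / old_w
    out[..., 1] *= new_h / old_h
    return out


def build_decoder_inputs(image_embedding, coords, labels, orig_hw, prev_low_res_mask=None):
    """Hypothetical helper: assemble the feed dict for the exported decoder.

    coords: (N, 2) pixel coords in the original image.
    labels: (N,) with 1 = foreground point, 0 = background point.
    orig_hw: (H, W) of the original image.
    prev_low_res_mask: optional (1, 1, 256, 256) low-res logits from a previous run.
    """
    # SAM's ONNX example appends a padding point (0, 0) with label -1 when no box is given
    pts = np.concatenate([np.asarray(coords, np.float32), np.zeros((1, 2), np.float32)], axis=0)[None]
    lbl = np.concatenate([np.asarray(labels, np.float32), np.array([-1.0], np.float32)])[None]

    if prev_low_res_mask is None:
        # No mask prompt: zeros plus has_mask_input = 0, so the no-mask embedding is used
        mask_input = np.zeros((1, 1, LOW_RES, LOW_RES), dtype=np.float32)
        has_mask_input = np.zeros(1, dtype=np.float32)
    else:
        # Refinement step: feed back the previous low-res mask and switch the flag on
        mask_input = np.asarray(prev_low_res_mask, dtype=np.float32).reshape(1, 1, LOW_RES, LOW_RES)
        has_mask_input = np.ones(1, dtype=np.float32)

    return {
        "image_embeddings": np.asarray(image_embedding, dtype=np.float32),
        "point_coords": transform_coords(pts, orig_hw),
        "point_labels": lbl,
        "mask_input": mask_input,
        "has_mask_input": has_mask_input,
        "orig_im_size": np.array(orig_hw, dtype=np.float32),
    }
```

With onnxruntime, the dict can be passed straight to `session.run(None, feed)`. Given the output order in the export above, the third output (`masks`, the 256x256 low-res logits) is what would be passed back as `prev_low_res_mask` on the next click. `image_embedding` is assumed to come from running the image encoder separately, e.g. `SamPredictor.get_image_embedding()` in the original SAM code.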