
TorchScript / PyTorch Mobile Support

Open cmarschner opened this issue 1 year ago • 15 comments

Description

This PR makes the model compilable using torch.jit.script() and adds a conversion tool that saves the model in a format that can be consumed by PyTorch Lite on iOS devices.

Changes for Torchscript:

  • Corrected or added type annotations
  • Changed loops to use enumerate()
  • Avoided statics
  • Removed @torch.no_grad (unsupported)
  • Added some type asserts and torch.jit.annotate() calls to facilitate type inference from Dict[str, Any]
  • Removed **kwargs (unsupported)
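A minimal sketch of what a TorchScript-friendly module along these lines looks like (the module and names below are hypothetical stand-ins, not the PR's actual code): typed inputs and outputs, enumerate() instead of manual indexing, torch.jit.annotate() to pin down a container's type, and no **kwargs or @torch.no_grad:

```python
import torch
from typing import Dict, List


class ScriptableHead(torch.nn.Module):
    """Hypothetical module illustrating the TorchScript fixes listed above."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, xs: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        # annotate() tells the compiler this is Dict[str, Tensor], not Dict[str, Any]
        out = torch.jit.annotate(Dict[str, torch.Tensor], {})
        # enumerate() instead of range()/index arithmetic
        for i, x in enumerate(xs):
            out["feat_" + str(i)] = self.proj(x)
        return out


# compiles cleanly: no **kwargs, no @torch.no_grad, fully typed signature
m = torch.jit.script(ScriptableHead())
```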

Pytorch mobile conversion

Example: python ./scripts/convert_pytorch_mobile.py output_dir

The result can be loaded as described in https://pytorch.org/tutorials/prototype/ios_gpu_workflow.html

BUT: the current version only runs on CPU on PyTorch Mobile. The Metal backend appears to be missing strided convolution.

The caller still needs to provide input scaling and normalization, as is done in the predictor example.
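For illustration, that caller-side preprocessing looks roughly like the following. The mean/std and 1024-pixel target match the SAM predictor's usual defaults, but verify them against the predictor code in the repo before relying on this sketch:

```python
import torch
import torch.nn.functional as F


def preprocess(image: torch.Tensor, img_size: int = 1024) -> torch.Tensor:
    """Sketch of the scaling/normalization the caller must do before the
    encoder. Values assumed from the SAM predictor defaults; verify them."""
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

    # resize so the longest side equals img_size, keeping aspect ratio
    h, w = image.shape[-2:]
    scale = img_size / max(h, w)
    nh, nw = int(h * scale + 0.5), int(w * scale + 0.5)
    image = F.interpolate(image[None].float(), size=(nh, nw),
                          mode="bilinear", align_corners=False)[0]

    image = (image - mean) / std
    # zero-pad bottom/right up to a square img_size x img_size input
    return F.pad(image, (0, img_size - nw, 0, img_size - nh))
```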

cmarschner avatar Nov 04 '23 19:11 cmarschner

Loading model...
Traceback (most recent call last):
  File "/Users/le/Downloads/MobileSAM-cmarschner-convert/./scripts/convert_pytorch_mobile.py", line 40, in
    embedding_model_ts = torch.jit.script(
                         ^^^^^^^^^^^^^^^^^
  .......
  File "/usr/local/lib/python3.11/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 164
    def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
                                                                 ~~~~~~~ <--- HERE
    r"""Checkpoint a model or part of the model
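The failure above comes from torch.utils.checkpoint.checkpoint() (whose *args/**kwargs signature TorchScript can't compile) getting pulled into scripting. One possible workaround, assuming the checkpoint call in your copy of the model can be guarded, is to exclude that path from compilation; the module below is a hypothetical stand-in, not MobileSAM's actual encoder block:

```python
import torch
import torch.utils.checkpoint


class Block(torch.nn.Module):
    """Hypothetical encoder block showing one way around the scripting error:
    keep checkpoint() on an @torch.jit.unused path so its **kwargs never
    reach the TorchScript compiler."""

    def __init__(self):
        super().__init__()
        self.body = torch.nn.Linear(4, 4)
        self.use_checkpoint = False  # leave False when exporting

    @torch.jit.unused  # excluded from compilation entirely
    def _checkpointed(self, x: torch.Tensor) -> torch.Tensor:
        return torch.utils.checkpoint.checkpoint(self.body, x, use_reentrant=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # in compiled code, is_scripting() is True, so the unused path is dead
        if not torch.jit.is_scripting() and self.use_checkpoint:
            return self._checkpointed(x)
        return self.body(x)


scripted = torch.jit.script(Block())  # compiles without the NotSupportedError
```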

CoderXXLee avatar Nov 11 '23 03:11 CoderXXLee

This is a limitation of TorchScript's supported operators, which are in fact quite limited. TorchScript has a lot of restrictions; a variable number of arguments is only one of the many things it can't handle.

For example, you can't use the any keyword in your model, as it is a Python function. See this similar issue for an analogous situation.

Rambling side note: this necessity to script modules is, and has always been, a source of headaches for developers. I actually think PyTorch's efforts on TorchScript integration might stagnate in the future, the reason being that ExecuTorch is the latest hotness.

Edit: I don't want to discourage you, of course; I'd love to see this pull request merged ;)

cyrillkuettel avatar Nov 11 '23 03:11 cyrillkuettel

This is a limitation of TorchScript's supported operators, which are in fact quite limited. TorchScript has a lot of restrictions; a variable number of arguments is only one of the many things it can't handle.

For example, you can't use the any keyword in your model, as it is a Python function. See this similar issue for an analogous situation.

Rambling side note: this necessity to script modules is, and has always been, a source of headaches for developers. I actually think PyTorch's efforts on TorchScript integration might stagnate in the future, the reason being that ExecuTorch is the latest hotness. It might even eventually replace PyTorch Mobile, as it does not depend on TorchScript.

Edit: I don't want to discourage you, of course; I'd love to see this pull request merged ;)

Thank you. Problem solved.

CoderXXLee avatar Nov 11 '23 05:11 CoderXXLee

Great! Did you manage to run it on Android / iOS @CoderXXLee ?

cyrillkuettel avatar Nov 11 '23 12:11 cyrillkuettel

Great! Did you manage to run it on Android / iOS @CoderXXLee ?

Yes, I'm currently running it on Android

CoderXXLee avatar Nov 12 '23 03:11 CoderXXLee

@cyrillkuettel You can download it here. Link: https://pan.baidu.com/s/1B_j7hBGjC5mNvYR5Q6p0Gw?pwd=1pxp Extraction code: 1pxp

CoderXXLee avatar Nov 12 '23 04:11 CoderXXLee

Loading model...
Traceback (most recent call last):
  File "/Users/le/Downloads/MobileSAM-cmarschner-convert/./scripts/convert_pytorch_mobile.py", line 40, in
    embedding_model_ts = torch.jit.script(
                         ^^^^^^^^^^^^^^^^^
  .......
  File "/usr/local/lib/python3.11/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 164
    def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
                                                                 ~~~~~~~ <--- HERE
    r"""Checkpoint a model or part of the model

This shouldn't happen (I was able to convert things successfully) - did you figure out why this happened @CoderXXLee ?

cmarschner avatar Nov 12 '23 19:11 cmarschner

does this mean, if we have a different orig_im_size, we would have to re-export the model?

 "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),

cyrillkuettel avatar Nov 12 '23 20:11 cyrillkuettel

Loading model...
Traceback (most recent call last):
  File "/Users/le/Downloads/MobileSAM-cmarschner-convert/./scripts/convert_pytorch_mobile.py", line 40, in
    embedding_model_ts = torch.jit.script(
                         ^^^^^^^^^^^^^^^^^
  .......
  File "/usr/local/lib/python3.11/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 164
    def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
                                                                 ~~~~~~~ <--- HERE
    r"""Checkpoint a model or part of the model

This shouldn't happen

In my case it also worked splendidly. Not sure what the error might have been.

cyrillkuettel avatar Nov 12 '23 20:11 cyrillkuettel

I was able to implement it in C++. I decided to share my project with the community: Libtorch-MobileSAM-Example.

[screenshot: MobileSAM1_20231113_025139]

cyrillkuettel avatar Nov 13 '23 02:11 cyrillkuettel

Great! Did you manage to run it on Android / iOS @CoderXXLee ?

Yes, I'm currently running it on Android

Hello, is there any code that implements TensorRT acceleration with C++ inference?

xiangw369 avatar Nov 13 '23 13:11 xiangw369

does this mean, if we have a different orig_im_size, we would have to re-export the model?

 "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),

No, this must be a glitch.

cmarschner avatar Nov 27 '23 10:11 cmarschner

does this mean, if we have a different orig_im_size, we would have to re-export the model?

 "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),

No, this must be a glitch.

It worked fine, thank you. I was just wondering what the implications are of this value [1500, 2250] being fixed.
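For what it's worth, orig_im_size in the exported decoder is an ordinary runtime input tensor, so a different image size should not require re-exporting the model. The sketch below illustrates that with a hypothetical stand-in for the decoder's final mask-resize step, not MobileSAM's actual decoder:

```python
import torch
import torch.nn.functional as F


class MaskUpscaler(torch.nn.Module):
    """Stand-in for the decoder's final resize: since orig_im_size arrives
    at call time, one exported model serves any image size."""

    def forward(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        h = int(orig_im_size[0].item())
        w = int(orig_im_size[1].item())
        return F.interpolate(masks, size=[h, w],
                             mode="bilinear", align_corners=False)


m = torch.jit.script(MaskUpscaler())
low_res = torch.randn(1, 1, 256, 256)
a = m(low_res, torch.tensor([1500.0, 2250.0]))  # the example's size
b = m(low_res, torch.tensor([480.0, 640.0]))    # a different size, same model
```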

cyrillkuettel avatar Nov 27 '23 10:11 cyrillkuettel

@cmarschner thanks for doing this! I couldn't get it to run & produce output (same error as @CoderXXLee reported), but this discussion led me to the models @cyrillkuettel shared. Thanks @cyrillkuettel !!

cummins-orgs avatar Mar 07 '24 00:03 cummins-orgs

I'm glad you find it useful. I went through a lot of pain creating these😅

Link to models: example-app/models/

cyrillkuettel avatar Mar 08 '24 02:03 cyrillkuettel