
Full compatibility with Segment Anything

Open • astnmsn opened this issue 10 months ago • 12 comments

I have not found any equivalent request

Feature description

Add support for the operators needed to run the vit_b SAM model, which can be found here:

PTH: direct download

ONNX file: sam_vit_b.onnx.zip

I have inspected the model using Netron and compared the nodes to the support list here.

Below is the list of operators in the network that are partially or fully missing from the support list, in the form Operator (missing support):

  • Mul (Import)
  • MatMul (Import)
  • Sin (Import)
  • Shape (Import)
  • Expand (Import)
  • Not (Import)
  • ReduceMean (Import)
  • ConstantOfShape (Full)
  • Where (Import)
  • Slice (Import)
  • OneHot (Import)
  • Tile (Import)
  • LayerNormalization (Import)
  • Gemm (Import)
  • ReduceMax (Import)
  • Floor (Full)
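The comparison described above (op types found in the graph versus the supported-ops list) amounts to a set difference. A minimal Rust sketch, assuming the op types have already been collected as strings; the function name missing_ops and both input lists are hypothetical excerpts, not the full SAM graph or Burn's actual support list:

```rust
use std::collections::BTreeSet;

/// Returns the op types used by the model that are absent from `supported`,
/// deduplicated and sorted by name.
fn missing_ops<'a>(model_ops: &[&'a str], supported: &[&str]) -> Vec<&'a str> {
    let supported: BTreeSet<&str> = supported.iter().copied().collect();
    // A node type can appear many times in a graph; the BTreeSet deduplicates and sorts.
    let used: BTreeSet<&'a str> = model_ops.iter().copied().collect();
    used.into_iter().filter(|op| !supported.contains(*op)).collect()
}

fn main() {
    // Hypothetical excerpt of node op types read off the graph, with repeats.
    let model_ops = ["Conv", "Mul", "MatMul", "Mul", "Sin", "Relu", "Floor"];
    // Hypothetical excerpt of a supported-ops list (not Burn's real list).
    let supported = ["Conv", "Relu", "Add"];
    println!("{:?}", missing_ops(&model_ops, &supported));
    // prints ["Floor", "MatMul", "Mul", "Sin"]
}
```

A check like this only answers "is the op imported at all"; partial support (for example, an op that imports but lacks some attributes) still needs a manual look at the support list.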

Feature motivation

I am currently trying to load and run the Segment Anything ONNX model inside an image-editing app to provide a masking experience similar to the demo available here. I am integrating it into an existing Rust codebase using wgpu that compiles to Wasm and runs in the browser.

Before I can select Burn as the ML library for this workflow, I need to be sure that it supports the operators used in the model.

astnmsn, Mar 27 '24 21:03

Thanks for filing this. This is helpful as we prioritize ONNX ops. If you have a direct link to the ONNX file, can you post it as well?

antimora, Mar 27 '24 21:03

Updating Expand to Import, since we just added this op. I need to update the docs.

antimora, Mar 27 '24 21:03

Submitted a PR to fix the supported OPs document: https://github.com/tracel-ai/burn/pull/1547

antimora, Mar 27 '24 21:03

This is the model download link provided by the SAM repo; I have also added it to the original post.

astnmsn, Mar 27 '24 22:03

I think it might be faster to implement the model manually in Burn and load the .pth weights file, which we now support.

You can check out an existing model to see how it's done: https://github.com/tracel-ai/models/tree/main/resnet-burn

We also have a YOLOX object detection PR in the works: https://github.com/tracel-ai/models/pull/24

@laggui has written a great tutorial on this subject: https://dev.to/laggui/transitioning-from-pytorch-to-burn-45m

Recently, we made tons of enhancements to the PyTorchFileRecorder: https://discord.com/channels/1038839012602941528/1144670451763785769/1216788417984335872

@laggui, @nathanielsimard, @ashdtu, would this be worth implementing ourselves? Should we move this ticket to the models repo?

antimora, Mar 28 '24 04:03

The community is always one step ahead 😄

We've actually discussed adding SAM to our models and this was in the plans following the release.

We still haven't decided whether we want to reimplement it and use the PyTorch file recorder to import the weights or use the ONNX import.

laggui, Mar 28 '24 12:03

@laggui, if we decide to work on this, I am more inclined to add the missing ONNX ops. That would give the biggest bang for the buck, rather than spending time writing the model by hand (although I am not sure how complex it is).

antimora, Mar 28 '24 15:03

Btw, I'm not sure if anyone has delved into the SAM code for ONNX export, but it doesn't include all the operations needed to run the model on an input image. The image encoder is left out of the ONNX export entirely, and the exported ONNX model expects image embeddings as input.

In their example, they still use their PyTorch implementation to compute the embeddings that are passed to ONNX Runtime.

So even if we support the missing operations in this issue, SAM support will still not be complete. Is this what you expected @astnmsn?

laggui, Apr 09 '24 14:04

@laggui Thanks for asking, and yes, that is expected. We plan to run the first half of the model (generating the embeddings) on the backend using PyTorch. Only the second half, which produces the masks from the embeddings and the cursor/click positions, will run on the client.

astnmsn, Apr 09 '24 15:04
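For the client-side half described above, the click positions have to be mapped into the frame the image encoder saw before they reach the decoder. A minimal Rust sketch, assuming the preprocessing used in the SAM repo's ONNX example (longest image side resized to 1024; labels 1 for foreground and 0 for background, cast to float; a padding point (0, 0) with label -1 appended when no box prompt is used). Prompt, build_prompt and preprocess_shape are hypothetical names. In that example, the remaining decoder inputs are the image embeddings ([1, 256, 64, 64]), an all-zero mask_input ([1, 1, 256, 256]), has_mask_input set to 0, and orig_im_size as (height, width).

```rust
/// Mirrors the longest-side resize assumed above: scale so the longer side equals
/// `long_side`, rounding each side to the nearest integer.
fn preprocess_shape(old_h: usize, old_w: usize, long_side: usize) -> (usize, usize) {
    let scale = long_side as f64 / old_h.max(old_w) as f64;
    let new_h = (old_h as f64 * scale + 0.5) as usize;
    let new_w = (old_w as f64 * scale + 0.5) as usize;
    (new_h, new_w)
}

/// Hypothetical container for the flattened decoder prompt tensors:
/// coords has shape [1, N + 1, 2] and labels has shape [1, N + 1].
struct Prompt {
    coords: Vec<f32>,
    labels: Vec<f32>,
}

/// `clicks` holds (x, y, label) in original-image pixel coordinates.
fn build_prompt(clicks: &[(f32, f32, f32)], orig_h: usize, orig_w: usize) -> Prompt {
    let (new_h, new_w) = preprocess_shape(orig_h, orig_w, 1024);
    let sx = new_w as f32 / orig_w as f32;
    let sy = new_h as f32 / orig_h as f32;
    let mut coords = Vec::with_capacity((clicks.len() + 1) * 2);
    let mut labels = Vec::with_capacity(clicks.len() + 1);
    for &(x, y, label) in clicks {
        coords.push(x * sx);
        coords.push(y * sy);
        labels.push(label);
    }
    // Padding point with label -1 (assumed: used when there is no box prompt).
    coords.extend_from_slice(&[0.0, 0.0]);
    labels.push(-1.0);
    Prompt { coords, labels }
}

fn main() {
    // One foreground click at (100, 50) on a hypothetical 512x512 image.
    let p = build_prompt(&[(100.0, 50.0, 1.0)], 512, 512);
    println!("coords = {:?}, labels = {:?}", p.coords, p.labels);
    // prints coords = [200.0, 100.0, 0.0, 0.0], labels = [1.0, -1.0]
}
```

Keeping this step outside the ONNX graph means the client only needs the image size and the server-computed embeddings; none of the encoder ops have to run in the browser.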

Regarding the Tile op: we need to rename our current repeat op to repeat_dim and implement a proper repeat that handles all dimensions at once.

antimora, Apr 12 '24 18:04
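The Tile/repeat distinction mentioned above: ONNX Tile takes one repeat count per axis and repeats the whole tensor along all of them in a single call, whereas a single-axis repeat handles one dimension at a time. A minimal Rust sketch on a row-major Vec<f32>, independent of Burn's tensor API; tile and repeat_dim here are hypothetical helpers, and repeat_dim assumes concatenation-style repetition (like PyTorch's Tensor.repeat), which is not checked against Burn's implementation:

```rust
/// Row-major strides for a shape.
fn strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        s[i] = s[i + 1] * shape[i + 1];
    }
    s
}

/// ONNX Tile semantics: output axis i has size shape[i] * repeats[i], and each
/// output coordinate maps back to the input coordinate (coord mod shape[i]).
fn tile(data: &[f32], shape: &[usize], repeats: &[usize]) -> (Vec<f32>, Vec<usize>) {
    assert_eq!(shape.len(), repeats.len(), "one repeat count per axis");
    let out_shape: Vec<usize> = shape.iter().zip(repeats).map(|(d, r)| d * r).collect();
    let in_strides = strides(shape);
    let out_len: usize = out_shape.iter().product();
    let mut out = Vec::with_capacity(out_len);
    for flat in 0..out_len {
        // Decode the output index axis by axis (last axis fastest) and wrap into the input.
        let mut rem = flat;
        let mut src = 0;
        for axis in (0..out_shape.len()).rev() {
            let coord = rem % out_shape[axis];
            rem /= out_shape[axis];
            src += (coord % shape[axis]) * in_strides[axis];
        }
        out.push(data[src]);
    }
    (out, out_shape)
}

/// Single-axis repeat, expressed as a tile with repeat count 1 on every other axis.
fn repeat_dim(data: &[f32], shape: &[usize], dim: usize, times: usize) -> (Vec<f32>, Vec<usize>) {
    let mut repeats = vec![1; shape.len()];
    repeats[dim] = times;
    tile(data, shape, &repeats)
}

fn main() {
    let data = [1.0, 2.0, 3.0, 4.0]; // 2x2, row-major
    let (out, shape) = tile(&data, &[2, 2], &[1, 2]);
    println!("{:?} {:?}", shape, out);
    // prints [2, 4] [1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
}
```

Because every axis wraps independently, tiling with repeats [2, 2] gives the same result as two successive single-axis repeats. An all-dimensions repeat could therefore be built on top of repeat_dim, at the cost of one intermediate tensor per repeated axis.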

Resolving this ticket will resolve https://github.com/tracel-ai/burn/issues/1560 as well.

antimora, Apr 18 '24 04:04

Current state of required ops based on the latest PRs:

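| op_type | Burn Import |
|---|---|
| Add | |
| Cast | |
| Concat | |
| Constant | |
| ConstantOfShape | |
| Conv | |
| ConvTranspose | |
| Cos | |
| Div | |
| Equal | |
| Erf | |
| Expand | |
| Floor | |
| Gather | |
| Gemm | |
| LayerNormalization | ✔️ |
| MatMul | ✔️ |
| Mul | |
| Not | ✔️ |
| OneHot | |
| Pow | |
| Reciprocal | |
| ReduceMax | ✔️ |
| ReduceMean | ✔️ |
| Relu | |
| Reshape | |
| Resize | |
| Shape | ✔️ |
| Sin | ✔️ |
| Slice | |
| Softmax | |
| Sqrt | |
| Sub | |
| Tile | |
| Transpose | |
| Unsqueeze | |
| Where | ✔️ |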

laggui, Apr 30 '24 15:04