Cannot export onnx

Open YaoJiawei329 opened this issue 2 years ago • 22 comments
Extraordinary work! I am trying to export the model to ONNX, but an error occurs. With opset=11, 12, or 13, the error message is: RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub. With opset=14, 15, 16, or 17, the error message is: ValueError: Unsupported ONNX opset version: 14

Environment: Windows 11, 12700H, 3070 Ti laptop, PyTorch 1.8.2, onnx 1.12

YaoJiawei329 avatar Apr 07 '23 03:04 YaoJiawei329

I get the same error: ValueError: Unsupported ONNX opset version: 17

AllenZYJ avatar Apr 07 '23 04:04 AllenZYJ

I created a new conda env with pytorch=1.12 and opset=1, which solved the problem.

YaoJiawei329 avatar Apr 07 '23 05:04 YaoJiawei329

Try PyTorch 2.0. The requirements are likely PyTorch 2.0 and opset version 17.

HighPoint avatar Apr 07 '23 05:04 HighPoint

I got the same problem. Changing the opset from 17 to 12 gets past it, but then I hit a new error: 'torch._C.Value' object is not iterable. Is this a PyTorch version problem?

torch:1.10 py:3.9.2

julinfn avatar Apr 07 '23 05:04 julinfn

Exporing onnx model to out/dd.onnx...
Traceback (most recent call last):
  File "/home/ubuntu/seg/segment-anything/scripts/export_onnx_model.py", line 180, in <module>
    run_export(
  File "/home/ubuntu/seg/segment-anything/scripts/export_onnx_model.py", line 154, in run_export
    torch.onnx.export(
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 463, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 200, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/__init__.py", line 313, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 994, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/symbolic_opset11.py", line 922, in repeat_interleave
    return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
  File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/symbolic_opset9.py", line 2064, in repeat_interleave
    for idx, r_split in enumerate(r_splits):
TypeError: 'torch._C.Value' object is not iterable (Occurred when translating repeat_interleave).

wavelet2008 avatar Apr 07 '23 06:04 wavelet2008

It seems that only PyTorch 2.0 (CUDA 11.7) is supported. I updated CUDA and PyTorch (other versions failed) and it works. By the way, I modified the onnxruntime.InferenceSession parameters at line 168:

ort_session = onnxruntime.InferenceSession(output,providers=['CUDAExecutionProvider'])
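Hard-coding only the CUDA provider can fail on machines without a GPU build of onnxruntime. A minimal sketch of a fallback, assuming you want CUDA when present and CPU otherwise; `pick_providers` is a hypothetical helper name, not part of onnxruntime or this repository:

```python
def pick_providers(available,
                   preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Return the preferred providers that are actually available, in order.

    `available` is expected to be the list returned by
    onnxruntime.get_available_providers(); it is passed in as an argument
    so this helper has no hard dependency on onnxruntime.
    """
    chosen = [name for name in preferred if name in available]
    if not chosen:
        raise RuntimeError(f"none of {preferred} found in {available}")
    return chosen
```

It could then be used as `onnxruntime.InferenceSession(output, providers=pick_providers(onnxruntime.get_available_providers()))`.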

MolianWH avatar Apr 07 '23 07:04 MolianWH

It no longer throws any errors after I updated to PyTorch 2.0 and onnx 1.13.1.
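The version reports so far can be turned into a quick preflight check. This is only a sketch built on what commenters in this thread reported (1.8.2 and 1.10 failed on repeat_interleave; 1.12 and 2.0 exported), not on any official compatibility table; the threshold and both function names are assumptions made for illustration:

```python
import re

# Assumption: derived only from reports in this thread, not from PyTorch docs.
MIN_REPORTED_WORKING = (1, 12)


def parse_major_minor(version):
    """Extract (major, minor) from strings like '1.8.2' or '2.0.0+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def likely_needs_patch(torch_version):
    """True if reports in this thread suggest the unpatched export may fail."""
    return parse_major_minor(torch_version) < MIN_REPORTED_WORKING
```

Calling `likely_needs_patch(torch.__version__)` before running the export script gives an early hint about whether to upgrade PyTorch or patch the decoder first.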

AllenZYJ avatar Apr 07 '23 09:04 AllenZYJ

Hey, guys! In this PR: https://github.com/facebookresearch/segment-anything/pull/210 I changed torch.repeat_interleave() to Tensor.expand(), after which I successfully exported the model under torch 1.8.2 + opset=12. But I'm not sure how this will affect performance.
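The substitution relies on the image embedding having a batch dimension of 1, in which case repeating it along dim 0 and expanding it produce the same values. A small sketch with stand-in tensors (the shapes and names below are illustrative, not the real SAM dimensions):

```python
import torch

# Stand-ins for image_embeddings (batch of 1) and the number of masks.
image_embeddings = torch.randn(1, 4, 3, 3)
num_masks = 5

# Original formulation: materializes num_masks copies along dim 0.
repeated = torch.repeat_interleave(image_embeddings, num_masks, dim=0)

# Replacement: broadcasts the singleton batch dimension without copying.
expanded = image_embeddings.expand(num_masks, *image_embeddings.shape[1:])

same_values = torch.equal(repeated, expanded)
# expand() returns a view, so it points at the same storage as the source.
shares_memory = expanded.data_ptr() == image_embeddings.data_ptr()
```

Since `expand()` returns a view rather than a copy, it should not cost more memory than `repeat_interleave`. Note that the two are interchangeable only when the leading dimension is 1; `expand()` raises an error if asked to grow a non-singleton dimension.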

UNeedCryDear avatar Apr 15 '23 07:04 UNeedCryDear

@UNeedCryDear It does indeed work.

lauraset avatar Apr 20 '23 02:04 lauraset

@UNeedCryDear Hello! Which Python file is this function (torch.repeat_interleave()) called in? Can you tell me the location of the file? Thanks a lot!

InterstellarFang avatar May 26 '23 03:05 InterstellarFang

https://github.com/facebookresearch/segment-anything/pull/210/files

UNeedCryDear avatar May 26 '23 05:05 UNeedCryDear

@UNeedCryDear Thank you! I made changes based on the code you provided (https://github.com/facebookresearch/segment-anything/pull/210/files), adding four lines of code, but it still reports an error (RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.) under torch 1.8.1 + opset=12.

InterstellarFang avatar May 26 '23 08:05 InterstellarFang

image

The code with a pink background has been replaced and you need to remove it.

UNeedCryDear avatar May 26 '23 08:05 UNeedCryDear

@UNeedCryDear I have commented out the two lines with a pink background by adding # to them, but it still reports an error (RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.) under torch 1.8.1 + opset=12. Is this a problem with the torch version? I would also like to confirm whether the ONNX opset default value is changed in these two files (notebooks/onnx_model_example.ipynb, scripts/export_onnx_model.py).

InterstellarFang avatar May 26 '23 09:05 InterstellarFang

Show me the code you modified.

UNeedCryDear avatar May 26 '23 09:05 UNeedCryDear

    # Expand per-image data in batch direction to be per-mask
    # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src_shape = (tokens.shape[0],*image_embeddings.shape[1:])
    src = image_embeddings.expand(src_shape)
    src = src + dense_prompt_embeddings
    # pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    pos_src_shape = (tokens.shape[0],*image_pe.shape[1:])
    pos_src = image_pe.expand(pos_src_shape)
    b, c, h, w = src.shape

@UNeedCryDear

InterstellarFang avatar May 26 '23 09:05 InterstellarFang

image

I searched the whole project for repeat_interleave, and this is the only place it is called. Your code is correct, unless the file you edited is not the one that is actually being run. So, have you saved your modifications?
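The same project-wide search can be scripted. A minimal sketch using only the standard library; `find_call_sites` is a hypothetical helper, and skipping lines that start with `#` is a simplification that only catches whole-line comments:

```python
from pathlib import Path


def find_call_sites(root, needle="repeat_interleave"):
    """Return (path, line_number, line) for each non-comment line containing needle."""
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for number, line in enumerate(text.splitlines(), start=1):
            stripped = line.strip()
            # Commented-out calls do not reach the exporter, so ignore them.
            if needle in stripped and not stripped.startswith("#"):
                hits.append((str(path), number, stripped))
    return hits
```

Running it on the root of the cloned repository should return an empty list once every live call has been replaced.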

UNeedCryDear avatar May 26 '23 09:05 UNeedCryDear

@UNeedCryDear Yes, I saved my modifications. Should I try other ONNX opset values besides 12?

InterstellarFang avatar May 26 '23 09:05 InterstellarFang

@UNeedCryDear Hello! I don't know how to fix the error. Can you give me some advice? Thanks a lot!

InterstellarFang avatar May 27 '23 03:05 InterstellarFang

According to the error, the modification was not applied successfully. You can search the whole project the way I did and find out where the change was not made correctly. Also, if you are using Jupyter Notebook or Colab, the files you modified may not be the files that are actually running. The correct approach is to git clone the code and modify it locally instead of using a pip-installed copy.

image
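One way to check whether the edited files are the ones Python actually imports is to ask the import system where it would load the package from. A sketch using the standard library; `module_origin` is a hypothetical helper name, and it is meant for top-level module names only:

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin
```

If `module_origin("segment_anything")` points into `site-packages` rather than your local clone, edits to the cloned `mask_decoder.py` will have no effect on the export.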

Finally, here is a .py file, based on the modification in PR #210, that exports successfully for me. If you still cannot export, I suggest switching to PyTorch 2.0. Good luck!

python export.py --checkpoint path/to/checkpoint --type vit_b --opset 12

import torch
import warnings
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import argparse
import onnx


def export_onnx(
        sam_checkpoint="sam_vit_b_01ec64.pth",
        model_type = "vit_b",
        opset=12,
        onnx_model_path="sam_onnx_example_maskdeocde.onnx"):
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    onnx_model = SamOnnxModel(sam, return_single_mask=True)
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    # The mask prompt is 4x the spatial size of the image embedding.
    mask_input_size = [4 * x for x in embed_size]

    # Dummy inputs are used only to trace the model during export.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
    model_onnx = onnx.load(onnx_model_path)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model
    print("Done!")

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="./model_weights/sam_vit_b_01ec64.pth",
                        help="The path to the SAM model checkpoint.")
    parser.add_argument("--type", type=str, default="vit_b",
                        help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.")
    parser.add_argument("--opset", type=int, default=12, help="The ONNX opset version to use.")
    parser.add_argument("--output", type=str, default="sam_onnx_example_maskdeocde.onnx",
                        help="The filename to save the ONNX model to.")
    return parser.parse_args()


if __name__ == '__main__':
    opt = parse_opt()
    export_onnx(opt.checkpoint, opt.type, opt.opset, opt.output)

UNeedCryDear avatar May 29 '23 01:05 UNeedCryDear

Hey, guys! In this PR: #210 I changed torch.repeat_interleave() to Tensor.expand(), after which I successfully exported the model under torch 1.8.2 + opset=12. But I'm not sure how this will affect performance.

It is true that onnx can be exported successfully, but the web demo cannot be used normally.

captainIT avatar Jun 06 '24 09:06 captainIT

I'm sorry, I can't help you; I am not familiar with the web side at all. If you need to use the web demo, it is best to use the original code.

UNeedCryDear avatar Jun 06 '24 09:06 UNeedCryDear