burn_import::onnx::ModelGen hangs in build.rs if record type is NamedMpkGz

Open seftontycho opened this issue 2 years ago • 8 comments

I have the following build.rs file that I have copied from https://burn.dev/book/import/onnx-model.html.

use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust code from the ONNX model file
    ModelGen::new()
        .input("src/model/identifier.onnx")
        .out_dir("model/")
        .run_from_script();
}

When I try to build, it gets to my_crate_name(build) and then never completes. I waited for 2 hours and nothing happened.

No errors are thrown either. identifier.onnx is a ResNet-50 model that I exported from PyTorch.

Any ideas what I am doing wrong?

seftontycho avatar Nov 13 '23 13:11 seftontycho

Thanks for reporting. Are you able to share your ONNX file so that we can attempt to reproduce the issue?

antimora avatar Nov 13 '23 13:11 antimora

Hi, thanks for the fast response. I had to upload the ONNX file to Google Drive as it was too large to attach.

https://drive.google.com/file/d/1pt7aOK_4zrQYW3bs6bSpRKk0Lhr5OOg6/view?usp=drive_link

seftontycho avatar Nov 13 '23 13:11 seftontycho

Yes, I can confirm the build keeps going without finishing. I am not sure what the root cause is yet. I'll report back.

antimora avatar Nov 13 '23 20:11 antimora

OK, I have narrowed down the problem. If the record type is the gzip-compressed named MessagePack format (RecordType::NamedMpkGz), the build hangs, while the other record types work. NamedMpkGz is the default record type.

The following will work (the build should be fast):

    ModelGen::new()
        .input("src/model/identifier.onnx")
        .out_dir("model/")
        .record_type(RecordType::NamedMpk)
        .embed_states(false)
        .run_from_script();

Or

    ModelGen::new()
        .input("src/model/identifier.onnx")
        .out_dir("model/")
        .record_type(RecordType::Bincode)
        .embed_states(false)
        .run_from_script();

Or

    ModelGen::new()
        .input("src/model/identifier.onnx")
        .out_dir("model/")
        .record_type(RecordType::PrettyJson)
        .embed_states(false)
        .run_from_script();

This will hang:

    ModelGen::new()
        .input("src/model/identifier.onnx")
        .out_dir("model/")
        .record_type(RecordType::NamedMpkGz)
        .embed_states(false)
        .run_from_script();

So for now, as a workaround, I recommend using NamedMpk or Bincode as the record type. Use Bincode if you intend to embed the weights into your executable binary.
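For reference, here is how the workaround might look as a complete build.rs. The snippets above leave out the import for RecordType. This is a minimal sketch that assumes burn-import is listed under [build-dependencies] and that RecordType is exported from burn_import::onnx alongside ModelGen. The paths are the ones from the original report.

```rust
// build.rs
// Sketch only: assumes `burn-import` is a build dependency and that
// `RecordType` is exported from `burn_import::onnx` next to `ModelGen`.
use burn_import::onnx::{ModelGen, RecordType};

fn main() {
    // Generate Rust code from the ONNX model file
    ModelGen::new()
        .input("src/model/identifier.onnx")
        .out_dir("model/")
        // Uncompressed named MessagePack, which avoids the NamedMpkGz hang
        .record_type(RecordType::NamedMpk)
        // Keep the weights in a separate file instead of embedding them
        .embed_states(false)
        .run_from_script();
}
```

To embed the weights in the executable instead, the same script would presumably use RecordType::Bincode together with .embed_states(true).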

CCing @nathanielsimard so he's aware of this.

antimora avatar Nov 13 '23 21:11 antimora

@seftontycho, let us know if the workaround works for you. I can confirm that all ops in your ONNX file should be supported (at least the generated code compiles). I did not run the model.

antimora avatar Nov 13 '23 21:11 antimora

I can confirm that this works! Thank you so much for the very quick fix.

seftontycho avatar Nov 13 '23 21:11 seftontycho

Great! We will investigate the root cause of the NamedMpkGz hang and follow up with a fix.

antimora avatar Nov 13 '23 21:11 antimora

In PR #1019, I changed the default record type to NamedMpk as a temporary measure. We will keep this ticket open until we understand what the issue is.

antimora avatar Nov 30 '23 18:11 antimora

We decided not to support the compressed Gz version. Closing this issue.

antimora avatar Mar 01 '24 17:03 antimora