bug(onnx): model created using tf2onnx panics with non-valid Ident
Describe the bug
I have an Equinox JAX model I want to import into burn. I used jax2tf to get a TensorFlow model, then tf2onnx to obtain the ONNX file. When I try to import this into burn it panics with the message:
ERROR burn_import::logger: PANIC => panicked at /home/mrobins/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/burn-import-0.16.0/src/burn/ty.rs:123:19:
"jax2tf_rhs_/pjit_silu_/Const_2:0" is not a valid Ident
I'd imagine it is the ":" or "/" characters in the ident, which seem to be used to identify blocks and outputs in the model. Is this panic expected or a bug?
To Reproduce
You can get the onnx file here: https://github.com/martinjrobins/diffsol/raw/refs/heads/workspace/examples/neural-ode-weather-prediction/rhs.onnx
Then I read it in as per the onnx example: https://burn.dev/burn-book/import/onnx-model.html
Expected behavior
The model to import without panic
Desktop (please complete the following information):
- OS: Ubuntu 22.04
- Browser: Chrome
Additional context
Using burn v0.16.0
Following on from this, I manually removed all the special characters from the onnx graph nodes, but then hit the issue that the Split node was not supported in burn v0.16.0, so I upgraded to v0.17.0 using the main branch of this repo.
However, the model still fails to import with the following error:
DEBUG burn_import::burn::graph: Building the scope nodes len => '26'
ERROR burn_import::logger: PANIC => panicked at /home/mrobins/git/burn/crates/burn-import/src/formatter.rs:8:31:
Valid token tree: BadSourceCode("error: expected type, found `{`\n --> <stdin>:1:1443\n |\n1 | ... > { # [allow (unused_variables)] pub fn new (device : & B :: Device) -> Self { Self { phantom : core :: marker :: PhantomData , device : burn :: module :: Ignored (device . clone ()) , } } _blank_ ! () ; # [allow (clippy :: let_and_return , clippy :: approx_constant)] pub fn forward (& self , input1 : Tensor < B , 1 > , input2 : Tensor < B , 1 > ,) -> { let mut split_tensors = input1 . split_with_sizes ([128 , 64 , 2048 , 32 , 64 , 2 ,] , 0) ; let [split1_out1 , split1_out2 , split1_out3 , split1_out4 , split1_out5 , split1_out6] = split_tensors . try_into () . unwrap () ; let unsqueeze1_out1 : Tensor < B , 2 > = input2 . unsqueeze_dims (& [1 ,]) ; let reshape1_out1 = split1_out5 . reshape ([2 , 32 ,]) ; let reshape2_out1 = split1_out1 . reshape ([64 , 2 ,]) ; let matmul1_out1 = reshape2_out1 . matmul (unsqueeze1_out1) ; let squeeze1_out1 = matmul1_out1 . squeeze_dims (& []) ; let add1_out1 = squeeze1_out1 . add (split1_out2) ; let neg1_out1 = add1_out1 . clone () . neg () ; let exp1_out1 = neg1_out1 . exp () ; let add2_out1 = exp1_out1 . add_scalar (jax2tf_rhs__pjit_silu__Const_0) ; let div1_out1 = jax2tf_rhs__pjit_silu__Const_0 / add2_out1 ; let mul1_out1 = add1_out1 . mul_scalar (div1_out1) ; let unsqueeze2_out1 : Tensor < B , 2 > = mul1_out1 . unsqueeze_dims (& [1 ,]) ; let reshape3_out1 = split1_out3 . reshape ([32 , 64 ,]) ; let matmul2_out1 = reshape3_out1 . matmul (unsqueeze2_out1) ; let squeeze2_out1 = matmul2_out1 . squeeze_dims (& []) ; let add3_out1 = squeeze2_out1 . add (split1_out4) ; let neg2_out1 = add3_out1 . clone () . neg () ; let exp2_out1 = neg2_out1 . exp () ; let add4_out1 = exp2_out1 . add_scalar (jax2tf_rhs__pjit_silu__Const_0) ; let div2_out1 = jax2tf_rhs__pjit_silu__Const_0 / add4_out1 ; let mul2_out1 = add3_out1 . mul_scalar (div2_out1) ; let unsqueeze3_out1 : Tensor < B , 2 > = mul2_out1 . 
unsqueeze_dims (& [1 ,]) ; let matmul3_out1 = reshape1_out1 . matmul (unsqueeze3_out1) ; let squeeze3_out1 = matmul3_out1 . squeeze_dims (& []) ; let add5_out1 = squeeze3_out1 . add (split1_out6) ; } }\n | - while parsing this item list starting here ^ expected type - the item list ends here\n\n")
ok, after a little more digging I've found that:
- the original model (linked above) is parsed and run with no issues using candle-onnx, so I believe it to be a correct model
- my manually edited model I mentioned above is, however, incorrect. It seems like the ":0" syntax is necessary to distinguish different outputs from a node; I have a Split node that has a few different outputs. However, if ":0" or ":1" is left in the onnx file then burn panics with the original error
I note that the Split node was only added to burn 3 weeks ago, so perhaps it was not necessary to support multiple outputs when importing models previously?
Sorry for the delayed response!
You can get the onnx file here: https://github.com/martinjrobins/diffsol/raw/refs/heads/workspace/examples/neural-ode-weather-prediction/rhs.onnx
The link is invalid, but there seems to be a script to generate it here?
Could you share the ONNX model in this issue? Might have to zip it for github to accept the upload.
I'd imagine it is the ":" or "/" characters in the ident, which seem to be used to identify blocks and outputs in the model. Is this panic expected or a bug?
If this is valid ONNX, then this is a bug. During code generation we parse the names and format them for variables but looks like not all symbols are handled.
I note that the Split node was only added to burn 3 weeks ago, so perhaps it was not necessary to support multiple outputs when importing models previously?
As you noted, support for the split node was just added recently. Perhaps the implementation does not handle the whole specification. If you include the model I could take a look!
Yes, that script will generate the models, but I'll also attach it below. Thanks for looking into this :)
Ok so my initial hypothesis was correct.
If this is valid ONNX, then this is a bug. During code generation we parse the names and format them for variables but looks like not all symbols are handled.
The name formatting really only handles alphanumeric values.
https://github.com/tracel-ai/burn/blob/a1ca4346424197f42ab1b5bff7f9d1210b029318/crates/burn-import/src/burn/ty.rs#L67-L74
If we change the last line to replace the invalid ident values in your model, it correctly parses the model.
name.to_string().replace(":", "_").replace("/", "_")
But I get another error:
ERROR burn_import::logger: PANIC => panicked at /home/laggui/workspace/burn/crates/burn-import/src/burn/node/binary.rs:158:18:
Division is supported for tensor and scalar only
In your case, the model actually has a lhs scalar and rhs tensor, which is not handled.
A quick fix:
pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self {
    let function = match (&lhs, &rhs) {
        (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) },
        (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) },
        (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs },
        (Type::Scalar(_), Type::Tensor(_)) => {
            move |lhs, rhs| quote! { #rhs.recip().mul_scalar(#lhs) }
        }
        _ => panic!("Division is supported for tensor and scalar only"),
    };
    Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function))
}
And finally, the codegen worked... but it's not correct 😓
error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
--> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:55:46
|
55 | ...scalar(jax2tf_rhs__pjit_silu__Const_1_0);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope
error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
--> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:58:25
|
58 | .mul_scalar(jax2tf_rhs__pjit_silu__Const_1_0);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope
error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
--> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:66:46
|
66 | ...scalar(jax2tf_rhs__pjit_silu__Const_1_0);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope
error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
--> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:69:25
|
69 | .mul_scalar(jax2tf_rhs__pjit_silu__Const_1_0);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope
error[E0308]: mismatched types
--> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:44:57
|
44 | ...= input1.split_with_sizes([256, 64, 2048, 32, 128, 4], 0);
| ---------------- ^^^^^^^^^^^^^^^^^^^^^^^^^^^- help: try using a conversion method: `.to_vec()`
| | |
| | expected `Vec<usize>`, found `[{integer}; 6]`
| arguments to this method are incorrect
|
= note: expected struct `Vec<usize>`
found array `[{integer}; 6]`
note: method defined here
--> /home/laggui/workspace/burn/crates/burn-tensor/src/tensor/api/base.rs:1360:12
|
1360 | pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) ...
| ^^^^^^^^^^^^^^^^
The SplitNode codegen seems to be incorrect here (fix should be easy as it currently passes an array instead of a vec), but then we will get stuck at the constant values as per #1882.
Reported on the discord: https://discord.com/channels/1038839012602941528/1144670451763785769/1373027679418454016
Hi ! I'm trying to import this onnx model
ERROR burn_import::logger: PANIC => panicked at C:\Users\A2va\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-import-0.17.0\src\burn\ty.rs:122:19:
"inception_resnet_v1/lambda_23/mul/y:0" is not a valid Ident
The onnx file was produced by tf2onnx, which comes with this weird naming, so I wrote a Python script to replace / with . and : with _. But an error similar to the previous one is still present:
ERROR burn_import::logger: PANIC => panicked at C:\Users\A2va\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-import-0.17.0\src\burn\ty.rs:122:19:
"inception_resnet_v1.lambda_23.mul.y_0" is not a valid Ident
What's wrong with this naming? Or is it maybe something else?
Submitted a fix: #3208 (under review)
I just encountered the same error with / but also . as invalid characters.