burn
ONNX import tests on wgpu backend
Describe the bug
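When switching the backend from burn_ndarray::NdArray to burn_autodiff::Autodiff<burn_wgpu::Wgpu>, the ONNX import tests no longer compile, and most of them fail once the compile errors are worked around.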
To Reproduce
- in crates/burn-import/onnx-tests/Cargo.toml, add:
[features]
backend-autodiff-wgpu = ["burn-autodiff", "burn-wgpu"]
[dependencies]
burn-autodiff = { path = "../../burn-autodiff", version = "0.18.0", default-features = false, optional = true }
burn-wgpu = { path = "../../burn-wgpu", version = "0.18.0", optional = true, default-features = false }
- in crates/burn-import/onnx-tests/tests/test_mod.rs, add:
#[cfg(feature = "backend-autodiff-wgpu")]
type Backend = burn_autodiff::Autodiff<burn_wgpu::Wgpu>;
#[cfg(not(feature = "backend-autodiff-wgpu"))]
type Backend = burn_ndarray::NdArray<f32>;
- under crates/burn-import/onnx-tests/tests/, replace:
type Backend = burn_ndarray::NdArray<f32>
with:
use super::super::Backend;
in the mod.rs files, by running this command from crates/burn-import/onnx-tests/tests/ (the `sed -i ''` form is BSD/macOS syntax; GNU sed takes `-i` without the empty argument):
find ./ -name mod.rs -exec sed -i '' 's|^[[:space:]]*type Backend = burn_ndarray::NdArray<f32>;|use super::super::Backend;|' {} +
- under crates/burn-import/onnx-tests/, run:
cargo test --test test_mod --features backend-autodiff-wgpu
got:
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:41:9
|
41 | assert!(f_output.equal(f_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:42:9
|
42 | assert!(i_output.equal(i_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:43:9
|
43 | assert!(b_output.equal(b_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
After commenting out pub mod constant_of_shape; in test_mod.rs and running step 4 again, I got:
running 134 tests
test argmax::tests::argmax ... FAILED
test and::tests::and ... FAILED
test ceil::tests::ceil_test ... ok
test argmin::tests::argmin ... FAILED
test add::tests::add_scalar_to_int_tensor_and_int_tensor_to_int_tensor ... FAILED
......
test result: FAILED. 50 passed; 84 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.36s
Expected behavior
All code should compile and all tests should pass.
Desktop (please complete the following information):
- Device: Apple M1 Pro
- OS: macOS
- Version: 14.2.1 (23C71)
Additional context
The git repo version is:
commit b42a8b6556bf0f87077cdd0b2b5c0756f90ec8cb (HEAD -> main, origin/main, origin/HEAD)
Author: Jimmy Johnson <[email protected]>
Date: Fri Jul 11 15:13:38 2025 +0200
Updating documentation description for nonzero and nonzero_async (#3368)
Interesting. @laggui, @nathanielsimard is this expected behavior?
We should use the tensor data equality methods that are already defined, rather than i_output.equal(i_expected).all().into_scalar().
Linking #3134 for reference as it points to the bool underlying data type discrepancy with wgpu backend.
Ofc this only addresses the compilation errors.
Not sure what the test failures are, but wouldn't be surprised if they were related to backend data types and conversions in the equality checks. I think burn-import tests have strong assumptions given the ndarray backend data types.
We should probably fix those 🙂
[constant_of_shape] .into_scalar() returns u32, but assert! expects bool:
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:41:9
|
41 | assert!(f_output.equal(f_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:42:9
|
42 | assert!(i_output.equal(i_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:43:9
|
43 | assert!(b_output.equal(b_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
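The mismatch above can be reproduced in a small self-contained model. This is only a sketch: `BoolElem`, `Backend`, `CpuLike`, `GpuLike`, `ScalarBoolTensor` and `check` are hypothetical names chosen for illustration, not Burn's API. It assumes, as the compiler output suggests, that each backend chooses its own element type for boolean tensors, so `into_scalar()` is only a `bool` on some backends.

```rust
// Hypothetical model, not Burn's actual API: every backend picks the element
// type it uses to store boolean tensors.
pub trait BoolElem: Copy {
    /// Convert the backend-specific element into a plain `bool`.
    fn to_bool(self) -> bool;
}

impl BoolElem for bool {
    fn to_bool(self) -> bool {
        self
    }
}

impl BoolElem for u32 {
    fn to_bool(self) -> bool {
        self != 0
    }
}

pub trait Backend {
    type BoolElem: BoolElem;
}

/// Stands in for a backend that stores booleans as `bool`.
pub struct CpuLike;
/// Stands in for a backend that stores booleans as `u32`.
pub struct GpuLike;

impl Backend for CpuLike {
    type BoolElem = bool;
}

impl Backend for GpuLike {
    type BoolElem = u32;
}

/// A single-element boolean tensor, e.g. the result of an `all()` reduction.
pub struct ScalarBoolTensor<B: Backend> {
    value: B::BoolElem,
}

impl<B: Backend> ScalarBoolTensor<B> {
    pub fn new(value: B::BoolElem) -> Self {
        Self { value }
    }

    /// Returns the backend's element type, which is not always `bool`.
    pub fn into_scalar(self) -> B::BoolElem {
        self.value
    }
}

/// Backend-agnostic check: writing `assert!(t.into_scalar())` would only
/// compile when `B::BoolElem` is `bool`, so convert explicitly first.
pub fn check<B: Backend>(t: ScalarBoolTensor<B>) -> bool {
    t.into_scalar().to_bool()
}
```

With this shape, `assert!(check(tensor))` compiles for both stand-in backends, whereas asserting on `into_scalar()` directly only compiles for `CpuLike`.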
Wgpu does not support "&&" for vectors, so the generated shader fails validation:
---- and::tests::and stdout ----
wgpu error: Validation Error
Caused by:
In Device::create_shader_module
Shader '' parsing error: Incompatible operands: LogicalAnd(vec4<bool>, _)
Incompatible operands: LogicalAnd(vec4<bool>, _)
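For context, WGSL defines the short-circuiting `&&` only for scalar `bool`, while `&` is the component-wise AND that also accepts `vecN<bool>`. One possible direction is to pick the operator from the vectorization width when the shader source is generated. The helper names below (`wgsl_and_operator`, `wgsl_and_expr`) and the `vector_width` parameter are hypothetical and are not taken from the wgpu backend's code generator; this is a sketch of the idea only.

```rust
/// Hypothetical helper: choose the WGSL operator for a boolean AND.
/// `&&` is valid only for scalar `bool`; `&` also works on `vecN<bool>`.
pub fn wgsl_and_operator(vector_width: u8) -> &'static str {
    if vector_width > 1 {
        "&"
    } else {
        "&&"
    }
}

/// Hypothetical helper: format the WGSL expression for `lhs AND rhs`.
pub fn wgsl_and_expr(lhs: &str, rhs: &str, vector_width: u8) -> String {
    format!("{} {} {}", lhs, wgsl_and_operator(vector_width), rhs)
}
```

Using `&` unconditionally would also be valid WGSL for scalars, at the cost of losing short-circuit evaluation.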
When the testing backend is NdArray, the integer type used is i64, while for the wgpu backend, the integer type is i32 (likely because wgpu does not support i64). This causes test failures due to mismatched integer types:
---- argmin::tests::argmin stdout ----
thread 'argmin::tests::argmin' panicked at crates/burn-import/onnx-tests/tests/argmin/mod.rs:26:26:
assertion `left == right` failed: Data types differ (I32 != I64)
left: I32
right: I64
---- argmax::tests::argmax stdout ----
thread 'argmax::tests::argmax' panicked at crates/burn-import/onnx-tests/tests/argmax/mod.rs:23:26:
assertion `left == right` failed: Data types differ (I32 != I64)
left: I32
right: I64
Similar errors:
---- add::tests::add_scalar_to_int_tensor_and_int_tensor_to_int_tensor stdout ----
panicked at crates/burn-import/onnx-tests/tests/add/mod.rs:39:26:
assertion `left == right` failed: Data types differ (I32 != I64)
left: I32
right: I64
---- equal::tests::equal_scalar_to_scalar_and_tensor_to_tensor stdout ----
panicked at crates/burn-import/onnx-tests/tests/equal/mod.rs:25:30:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- greater::tests::greater stdout ----
panicked at crates/burn-import/onnx-tests/tests/greater/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- gather::tests::gather_shape stdout ----
panicked at crates/burn-import/onnx-tests/tests/gather/mod.rs:59:9:
assertion `left == right` failed
left: TensorData { bytes: Bytes { data: [2, 0, 0, "..."], len: 4 }, shape: [1], dtype: I32 }
right: TensorData { bytes: Bytes { data: [2, 0, 0, "..."], len: 8 }, shape: [1], dtype: I64 }
---- greater::tests::greater_scalar stdout ----
panicked at crates/burn-import/onnx-tests/tests/greater/mod.rs:37:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- greater_or_equal::tests::greater_or_equal_scalar stdout ----
panicked at crates/burn-import/onnx-tests/tests/greater_or_equal/mod.rs:38:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- greater_or_equal::tests::greater_or_equal stdout ----
panicked at crates/burn-import/onnx-tests/tests/greater_or_equal/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- less::tests::less stdout ----
panicked at crates/burn-import/onnx-tests/tests/less/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- less_or_equal::tests::less_or_equal stdout ----
panicked at crates/burn-import/onnx-tests/tests/less_or_equal/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- less::tests::less_scalar stdout ----
panicked at crates/burn-import/onnx-tests/tests/less/mod.rs:37:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
---- less_or_equal::tests::less_or_equal_scalar stdout ----
panicked at crates/burn-import/onnx-tests/tests/less_or_equal/mod.rs:37:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
left: U32
right: Bool
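All of the panics above share one shape: the values agree, but the expected data was built with NdArray's element types (I64, Bool) while the wgpu backend produces I32 and U32. The sketch below models a comparison that can either enforce or ignore the dtype. `Data`, `dtype`, `widened` and `check_eq` are hypothetical names for this illustration and do not use Burn; as far as I know, Burn's `TensorData::assert_eq` has a `strict` flag that plays a similar role, but treat that as an assumption to verify.

```rust
/// Hypothetical stand-in for tensor data tagged with its element type.
#[derive(Debug, Clone, PartialEq)]
pub enum Data {
    I32(Vec<i32>),
    I64(Vec<i64>),
    U32(Vec<u32>),
    Bool(Vec<bool>),
}

impl Data {
    /// Name of the element type, matching the names in the panic messages.
    pub fn dtype(&self) -> &'static str {
        match self {
            Data::I32(_) => "I32",
            Data::I64(_) => "I64",
            Data::U32(_) => "U32",
            Data::Bool(_) => "Bool",
        }
    }

    /// Widen every element to i64 so values can be compared across dtypes.
    fn widened(&self) -> Vec<i64> {
        match self {
            Data::I32(v) => v.iter().map(|&x| i64::from(x)).collect(),
            Data::I64(v) => v.clone(),
            Data::U32(v) => v.iter().map(|&x| i64::from(x)).collect(),
            Data::Bool(v) => v.iter().map(|&x| i64::from(x)).collect(),
        }
    }

    /// Compare two values; when `strict` is true the dtypes must match too.
    pub fn check_eq(&self, other: &Data, strict: bool) -> Result<(), String> {
        if strict && self.dtype() != other.dtype() {
            return Err(format!(
                "Data types differ ({} != {})",
                self.dtype(),
                other.dtype()
            ));
        }
        if self.widened() != other.widened() {
            return Err("Values differ".to_string());
        }
        Ok(())
    }
}
```

An alternative to relaxing the comparison is to build the expected data from the backend's own integer and bool element types, which keeps the dtype check meaningful.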
Reopening this issue — @laggui thanks again for merging the fix for the bool vs u32 mismatch!
Just noting that this issue originally covered several failing tests, and some of them (e.g., i64 vs i32 mismatches) seem unrelated to the bug that was just resolved. I’d like to verify whether the remaining failures can be addressed with a similar approach, and if not, I’ll work on separate fixes for them.
Ahh yes the "Fixes [issue number]" pattern in a PR automatically closes the issue when the PR merges 😅
Reopening for the other related issues mentioned in the comments.
Ahh got it — my bad 😅 I’ll break out future issues separately so they align better with the “Fixes #” pattern. Thanks for the heads-up!
No worries! In this case I think it's fine to have a single issue, but you can just link it without closing it as long as you don't have "fixes/closes/resolves [issue number]" (e.g., just link the [issue number] without one of these prefixes).
Received, thanks a lot for the clarification — really appreciate your help!