ONNX import tests on wgpu backend

Open lucianyao opened this issue 4 months ago • 10 comments

Describe the bug

After switching the backend from burn_ndarray::NdArray to burn_autodiff::Autodiff<burn_wgpu::Wgpu>, the tests under crates/burn-import/onnx-tests/ no longer all compile, and several ONNX import tests fail when run.

To Reproduce

  1. in crates/burn-import/onnx-tests/Cargo.toml, add:
[features]
backend-autodiff-wgpu = ["burn-autodiff", "burn-wgpu"]

[dependencies]
burn-autodiff = { path = "../../burn-autodiff", version = "0.18.0", default-features = false, optional = true }
burn-wgpu = { path = "../../burn-wgpu", version = "0.18.0", optional = true, default-features = false }
  2. in crates/burn-import/onnx-tests/tests/test_mod.rs, add:
#[cfg(feature = "backend-autodiff-wgpu")]
type Backend = burn_autodiff::Autodiff<burn_wgpu::Wgpu>;

#[cfg(not(feature = "backend-autodiff-wgpu"))]
type Backend = burn_ndarray::NdArray<f32>;
  3. under crates/burn-import/onnx-tests/tests/, replace:
type Backend = burn_ndarray::NdArray<f32>

with:

use super::super::Backend;

in every mod.rs file, by running this script from crates/burn-import/onnx-tests/tests/ (the sed -i '' form is the macOS/BSD syntax):

find ./ -name mod.rs -exec sed -i '' 's|^[[:space:]]*type Backend = burn_ndarray::NdArray<f32>;|use super::super::Backend;|' {} +
  4. under crates/burn-import/onnx-tests/, run:
cargo test --test test_mod --features backend-autodiff-wgpu

got:

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:41:9
   |
41 |         assert!(f_output.equal(f_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:42:9
   |
42 |         assert!(i_output.equal(i_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:43:9
   |
43 |         assert!(b_output.equal(b_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

After commenting out pub mod constant_of_shape; in test_mod.rs and running step 4 again, I got:

running 134 tests
test argmax::tests::argmax ... FAILED
test and::tests::and ... FAILED
test ceil::tests::ceil_test ... ok
test argmin::tests::argmin ... FAILED
test add::tests::add_scalar_to_int_tensor_and_int_tensor_to_int_tensor ... FAILED
......
test result: FAILED. 50 passed; 84 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.36s

Expected behavior

All code should compile and all tests should pass.

Desktop (please complete the following information):

  • Device: Apple M1 Pro
  • OS: macOS
  • Version: 14.2.1 (23C71)

Additional context

The git repo version is:

commit b42a8b6556bf0f87077cdd0b2b5c0756f90ec8cb (HEAD -> main, origin/main, origin/HEAD)
Author: Jimmy Johnson <[email protected]>
Date:   Fri Jul 11 15:13:38 2025 +0200

    Updating documentation description for nonzero and nonzero_async (#3368)

Jul 11 '25 19:07 lucianyao

Interesting. @laggui, @nathanielsimard is this expected behavior?

Jul 11 '25 19:07 antimora

The tests should use the tensor data equality methods that are already defined, not i_output.equal(i_expected).all().into_scalar().

Linking #3134 for reference as it points to the bool underlying data type discrepancy with wgpu backend.

Ofc this only addresses the compilation errors.

Not sure what the test failures are, but wouldn't be surprised if they were related to backend data types and conversions in the equality checks. I think burn-import tests have strong assumptions given the ndarray backend data types.

We should probably fix those 🙂

Jul 11 '25 19:07 laggui

[constant_of_shape] .into_scalar() returns u32, but assert! expects bool:

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:41:9
   |
41 |         assert!(f_output.equal(f_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:42:9
   |
42 |         assert!(i_output.equal(i_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:43:9
   |
43 |         assert!(b_output.equal(b_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
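
A minimal sketch of why this only shows up after the backend switch, assuming the bool element type is chosen per backend. The trait and type names here (BoolElem, Backend, CpuLike, GpuLike, outputs_match) are hypothetical stand-ins for illustration, not Burn's actual API:

```rust
// Hypothetical stand-in for a backend's boolean element type.
// This is a simplified model, not Burn's real trait hierarchy.
trait BoolElem: Copy {
    fn to_bool(self) -> bool;
}

impl BoolElem for bool {
    fn to_bool(self) -> bool {
        self
    }
}

// A backend that stores booleans as 32-bit integers: nonzero means true.
impl BoolElem for u32 {
    fn to_bool(self) -> bool {
        self != 0
    }
}

// Hypothetical backend trait: each backend picks its own bool representation.
trait Backend {
    type Bool: BoolElem;
    fn all_equal(a: &[f32], b: &[f32]) -> Self::Bool;
}

struct CpuLike; // models a backend whose bool element is `bool`
struct GpuLike; // models a backend whose bool element is `u32`

impl Backend for CpuLike {
    type Bool = bool;
    fn all_equal(a: &[f32], b: &[f32]) -> bool {
        a == b
    }
}

impl Backend for GpuLike {
    type Bool = u32;
    fn all_equal(a: &[f32], b: &[f32]) -> u32 {
        (a == b) as u32
    }
}

// Backend-agnostic check: convert the backend's bool element explicitly
// instead of handing it straight to `assert!`, which requires a `bool`.
fn outputs_match<B: Backend>(a: &[f32], b: &[f32]) -> bool {
    B::all_equal(a, b).to_bool()
}
```

With this shape, assert!(outputs_match::<GpuLike>(..)) compiles for either backend, whereas assert!(GpuLike::all_equal(..)) would hit the same "expected `bool`, found `u32`" error.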

Jul 16 '25 19:07 lucianyao

wgpu does not support "&&" on vector operands, so the shader module for the and test fails to build:

---- and::tests::and stdout ----

wgpu error: Validation Error

Caused by:
  In Device::create_shader_module

Shader '' parsing error: Incompatible operands: LogicalAnd(vec4<bool>, _)


      Incompatible operands: LogicalAnd(vec4<bool>, _)
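
For context, WGSL defines the short-circuiting && only for scalar bool operands, while the component-wise & also accepts vectors of bool. The sketch below models that distinction in Rust, with a [bool; 4] array standing in for vec4<bool>. The function names are illustrative, and whether the generated kernel should switch operators is an inference from the error message, not a confirmed fix:

```rust
// Models a WGSL `vec4<bool>` as a fixed-size array (illustrative stand-in).
type BVec4 = [bool; 4];

// Component-wise AND: every lane is combined independently, so there is
// nothing to short-circuit. This is the operation a vectorized kernel needs.
fn and_componentwise(a: BVec4, b: BVec4) -> BVec4 {
    std::array::from_fn(|i| a[i] & b[i])
}

// Scalar AND, where short-circuiting `&&` is well defined.
fn and_scalar(a: bool, b: bool) -> bool {
    a && b
}
```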

Jul 16 '25 19:07 lucianyao

When the testing backend is NdArray, the integer type used is i64, while for the wgpu backend, the integer type is i32 (likely because wgpu does not support i64). This causes test failures due to mismatched integer types:

---- argmin::tests::argmin stdout ----

thread 'argmin::tests::argmin' panicked at crates/burn-import/onnx-tests/tests/argmin/mod.rs:26:26:
assertion `left == right` failed: Data types differ (I32 != I64)
  left: I32
 right: I64

---- argmax::tests::argmax stdout ----

thread 'argmax::tests::argmax' panicked at crates/burn-import/onnx-tests/tests/argmax/mod.rs:23:26:
assertion `left == right` failed: Data types differ (I32 != I64)
  left: I32
 right: I64

Similar errors:

---- add::tests::add_scalar_to_int_tensor_and_int_tensor_to_int_tensor stdout ----

panicked at crates/burn-import/onnx-tests/tests/add/mod.rs:39:26:
assertion `left == right` failed: Data types differ (I32 != I64)
  left: I32
 right: I64

---- equal::tests::equal_scalar_to_scalar_and_tensor_to_tensor stdout ----

panicked at crates/burn-import/onnx-tests/tests/equal/mod.rs:25:30:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- greater::tests::greater stdout ----

panicked at crates/burn-import/onnx-tests/tests/greater/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- gather::tests::gather_shape stdout ----

panicked at crates/burn-import/onnx-tests/tests/gather/mod.rs:59:9:
assertion `left == right` failed
  left: TensorData { bytes: Bytes { data: [2, 0, 0, "..."], len: 4 }, shape: [1], dtype: I32 }
 right: TensorData { bytes: Bytes { data: [2, 0, 0, "..."], len: 8 }, shape: [1], dtype: I64 }
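
The len: 4 versus len: 8 difference in this dump is consistent with the same value being stored at two integer widths. A small sketch, assuming a little-endian byte layout (which matches the leading 2 in both dumps); the helper names are made up for illustration:

```rust
// Little-endian bytes of one value stored as a 32-bit integer.
fn bytes_i32(v: i32) -> Vec<u8> {
    v.to_le_bytes().to_vec()
}

// Little-endian bytes of the same value stored as a 64-bit integer.
fn bytes_i64(v: i64) -> Vec<u8> {
    v.to_le_bytes().to_vec()
}
```

bytes_i32(2) has 4 bytes and bytes_i64(2) has 8, with identical leading bytes, so a byte-level comparison fails even though both encode the value 2.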

---- greater::tests::greater_scalar stdout ----

panicked at crates/burn-import/onnx-tests/tests/greater/mod.rs:37:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- greater_or_equal::tests::greater_or_equal_scalar stdout ----

panicked at crates/burn-import/onnx-tests/tests/greater_or_equal/mod.rs:38:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- greater_or_equal::tests::greater_or_equal stdout ----

panicked at crates/burn-import/onnx-tests/tests/greater_or_equal/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- less::tests::less stdout ----

panicked at crates/burn-import/onnx-tests/tests/less/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- less_or_equal::tests::less_or_equal stdout ----

panicked at crates/burn-import/onnx-tests/tests/less_or_equal/mod.rs:23:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- less::tests::less_scalar stdout ----

panicked at crates/burn-import/onnx-tests/tests/less/mod.rs:37:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool

---- less_or_equal::tests::less_or_equal_scalar stdout ----

panicked at crates/burn-import/onnx-tests/tests/less_or_equal/mod.rs:37:26:
assertion `left == right` failed: Data types differ (U32 != Bool)
  left: U32
 right: Bool
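
All of these panics have the shape of a comparison that checks the dtype before looking at the values. The sketch below models that with a hypothetical Data type (not Burn's TensorData) to show the difference between a strict check and a value-only check. The tests could either relax the dtype check or build the expected data with the backend's own element types; which one fits burn-import is an open question here:

```rust
// Hypothetical dtype tag; only the variants seen in the failures above.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    I32,
    I64,
    U32,
    Bool,
}

// Simplified stand-in for tensor data: a dtype tag plus values widened to i64.
#[derive(Debug, Clone, PartialEq)]
struct Data {
    dtype: DType,
    values: Vec<i64>,
}

impl Data {
    fn from_i32(v: &[i32]) -> Self {
        Data { dtype: DType::I32, values: v.iter().map(|&x| i64::from(x)).collect() }
    }
    fn from_i64(v: &[i64]) -> Self {
        Data { dtype: DType::I64, values: v.to_vec() }
    }
    fn from_u32(v: &[u32]) -> Self {
        Data { dtype: DType::U32, values: v.iter().map(|&x| i64::from(x)).collect() }
    }
    fn from_bool(v: &[bool]) -> Self {
        Data { dtype: DType::Bool, values: v.iter().map(|&x| i64::from(x)).collect() }
    }
}

// Value-only comparison: ignores how the backend stores each element.
fn eq_values(a: &Data, b: &Data) -> Result<(), String> {
    if a.values == b.values {
        Ok(())
    } else {
        Err("Values differ".to_string())
    }
}

// Strict comparison: the dtype must match before values are compared.
fn eq_strict(a: &Data, b: &Data) -> Result<(), String> {
    if a.dtype != b.dtype {
        return Err(format!("Data types differ ({:?} != {:?})", a.dtype, b.dtype));
    }
    eq_values(a, b)
}
```

In this model, eq_strict on an I32 result and an I64 expectation returns an error even when every value matches, while eq_values accepts them. The same applies to a U32 mask compared against a Bool expectation.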

Jul 16 '25 19:07 lucianyao

Reopening this issue — @laggui thanks again for merging the fix for the bool vs u32 mismatch!

Just noting that this issue originally covered several failing tests, and some of them (e.g., i64 vs i32 mismatches) seem unrelated to the bug that was just resolved. I’d like to verify whether the remaining failures can be addressed with a similar approach, and if not, I’ll work on separate fixes for them.

Jul 17 '25 11:07 lucianyao

Ahh yes the "Fixes [issue number]" pattern in a PR automatically closes the issue when the PR merges 😅

Reopening for the other related issues mentioned in the comments.

Jul 17 '25 11:07 laggui

Ahh got it — my bad 😅 I’ll break out future issues separately so they align better with the “Fixes #” pattern. Thanks for the heads-up!

Jul 17 '25 12:07 lucianyao

No worries! In this case I think it's fine to have a single issue. You can link it from a PR without closing it, as long as the PR description doesn't use "fixes/closes/resolves [issue number]" (i.e., just mention the [issue number] without one of these prefixes).

Jul 17 '25 12:07 laggui

Received, thanks a lot for the clarification — really appreciate your help!

Jul 17 '25 12:07 lucianyao