
Constant tensors in ONNX models don't get values populated

Open · jameshiew opened this issue 11 months ago · 0 comments

Describe the bug

When loading an ONNX model that contains a constant tensor, the constant doesn't seem to get populated with its values. I'm not 100% sure whether this is a bug or whether I'm doing something wrong. I've been trying this with models generated for testing.

To Reproduce

I have an ONNX test, constant_tensor_f32, in a branch on my fork. It uses an ONNX model generated from this PyTorch script (constant_tensor_f32.onnx) and tries to add a constant tensor [[2,2],[2,2]] to the input [[0,0],[0,0]].
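
The arithmetic the test checks can be sketched in plain Rust, without Burn, as an element-wise addition of two 2x2 arrays. This is only an illustration of the expected result; the helper name `add_2x2` is hypothetical and does not come from the test branch.

```rust
/// Element-wise addition of two 2x2 matrices (hypothetical helper,
/// standing in for the Add node in the test model).
fn add_2x2(a: [[f32; 2]; 2], b: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
    let mut out = [[0.0f32; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            out[i][j] = a[i][j] + b[i][j];
        }
    }
    out
}

fn main() {
    let input = [[0.0f32, 0.0], [0.0, 0.0]];
    let constant = [[2.0f32, 2.0], [2.0, 2.0]];
    let output = add_2x2(input, constant);
    println!("{:?}", output); // [[2.0, 2.0], [2.0, 2.0]]
}
```

Because the input is all zeros, the output should be exactly the constant's values, so any all-zero output points at the constant rather than the Add.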

Running the ONNX tests (cargo nextest run --manifest-path crates/burn-import/onnx-tests/Cargo.toml) gives this error:

--- STDERR:              onnx-tests::test_onnx tests::constant_tensor_f32 ---
thread 'tests::constant_tensor_f32' panicked at crates/burn-import/onnx-tests/tests/test_onnx.rs:2242:9:
assertion `left == right` failed
  left: TensorData { bytes: [0, 0, 0, 64, 0, 0, 0, 64, 0, 0, 0, 64, 0, 0, 0, 64], shape: [2, 2], dtype: F32 }
 right: TensorData { bytes: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], shape: [2, 2], dtype: F32 }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
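
The TensorData values in the panic message are raw f32 bytes. Assuming they are little-endian (the native byte order on both Apple silicon and x86-64 Macs), each group of four bytes can be decoded with the standard library's `f32::from_le_bytes`: [0, 0, 0, 64] is 0x40000000, i.e. 2.0. So one side of the assertion holds the expected [[2,2],[2,2]] and the other holds all zeros. The helper `decode_f32` is illustrative only.

```rust
/// Decode a byte buffer of little-endian f32 values (illustrative helper).
fn decode_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    // Byte buffers copied from the `left` and `right` values in the panic message.
    let left: [u8; 16] = [0, 0, 0, 64, 0, 0, 0, 64, 0, 0, 0, 64, 0, 0, 0, 64];
    let right: [u8; 16] = [0; 16];
    println!("left  = {:?}", decode_f32(&left)); // [2.0, 2.0, 2.0, 2.0]
    println!("right = {:?}", decode_f32(&right)); // [0.0, 0.0, 0.0, 0.0]
}
```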

Expected behavior

The test should pass, producing the expected output [[2,2],[2,2]].

Screenshots

In Netron, the model appears to contain the [[2,2],[2,2]] constant tensor.

[Screenshot 2024-12-17 at 17:43:11]

Desktop

  • OS: macOS 15.2
  • Version: based on https://github.com/tracel-ai/burn/commit/8a89293bf3ee02fe7216705ed3b7370506489e4a; the ConstantNode functionality in the linked test branch shouldn't differ from that commit

Additional context

When running cargo run -p burn-import over that model, the generated Rust code appears to initialize the constant tensor to all zeros:

        let constant1: burn::module::Param<Tensor<B, 2>> = burn::nn::Initializer::Zeros
            .init([2, 2], device)
            .set_require_grad(false);
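
The effect of that generated code can be reproduced with a small, Burn-free model of the problem. The `ConstantInit` enum below is hypothetical and does not mirror burn-import's internals: it contrasts a zeros initializer (what the generated code above does) with initializing from the constant's stored values (what the Netron view suggests should happen). With the zeros variant, constant + input yields all zeros, matching the failing assertion.

```rust
/// Hypothetical stand-in for how a code generator might initialize a constant tensor.
enum ConstantInit {
    /// Fill with zeros, like the generated `Initializer::Zeros` call.
    Zeros,
    /// Use the values stored in the ONNX constant.
    Data(Vec<f32>),
}

impl ConstantInit {
    /// Materialize a flat buffer for a tensor with `len` elements.
    fn materialize(&self, len: usize) -> Vec<f32> {
        match self {
            ConstantInit::Zeros => vec![0.0; len],
            ConstantInit::Data(values) => {
                assert_eq!(values.len(), len, "constant data does not match shape");
                values.clone()
            }
        }
    }
}

/// Element-wise add of two flat buffers of the same length.
fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

fn main() {
    let input = vec![0.0f32; 4]; // [[0,0],[0,0]] flattened
    let zeros = ConstantInit::Zeros.materialize(4);
    let data = ConstantInit::Data(vec![2.0; 4]).materialize(4);
    println!("{:?}", add(&input, &zeros)); // [0.0, 0.0, 0.0, 0.0]
    println!("{:?}", add(&input, &data)); // [2.0, 2.0, 2.0, 2.0]
}
```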

jameshiew, Dec 17 '24 17:12