Tensor shape mismatch error when using candle-onnx with audio input
Description
I'm encountering a tensor reshape error when trying to run inference on an audio model using candle-onnx. The error occurs during model evaluation despite the input tensor seemingly having the correct shape.
Reproduction Code
use candle_core::{Device, Tensor};

// `MODEL` is the path to the ONNX file; `audio` holds f32 samples loaded from a WAV file.
let model = candle_onnx::read_file(MODEL).unwrap();
let graph = model.graph.as_ref().unwrap();

// Build a [1, 32000] input tensor from the first 32000 samples.
let input_tensor = Tensor::from_vec(audio[0..32000].to_vec(), (1, 32000), &Device::Cpu).unwrap();

let mut inputs: std::collections::HashMap<String, Tensor> = std::collections::HashMap::new();
inputs.insert(graph.input[0].name.to_string(), input_tensor);

let mut outputs = candle_onnx::simple_eval(&model, inputs).unwrap();
let output = outputs.remove(&graph.output[0].name).unwrap();
Input Information
When printing the model input information:
Input dims: ValueInfoProto {
name: "input",
r#type: Some(TypeProto {
denotation: "",
value: Some(TensorType(Tensor {
elem_type: 1,
shape: Some(TensorShapeProto {
dim: [
Dimension { denotation: "", value: Some(DimValue(1)) },
Dimension { denotation: "", value: Some(DimValue(32000)) }
]
})
}))
}),
doc_string: ""
}
Input Tensor shape: Tensor[dims 1, 32000; f32]
Error Message
thread 'main' panicked at src/main.rs:61:64:
called `Result::unwrap()` on an `Err` value: shape mismatch in reshape, lhs: [3], rhs: [1, 2]
Environment
- Rust version: 1.85.1
- candle-onnx version: 0.8.4
- OS: Linux
Additional Information
- The error happens during the `simple_eval` call, not during tensor creation
- The input tensor has the correct shape `[1, 32000]` as expected by the model
- The audio data is loaded from a WAV file
- The error seems to come from an internal reshape operation where something is trying to reshape a tensor of shape `[3]` to `[1, 2]`
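As a sanity check on the error itself: a reshape can only succeed when the product of the target dims equals the source element count, so a 3-element tensor can never become `[1, 2]`. A minimal plain-Rust sketch of that invariant (illustrative only, not candle's internals):

```rust
// Plain-Rust sketch of the invariant behind the panic message:
// a reshape is valid only when the target dims multiply out to the
// source element count.
fn can_reshape(src_numel: usize, target: &[usize]) -> bool {
    target.iter().product::<usize>() == src_numel
}

fn main() {
    // The failing case from the panic: lhs [3] -> rhs [1, 2].
    assert!(!can_reshape(3, &[1, 2])); // 3 elements cannot fill 1 * 2 slots
    // The model input itself is internally consistent.
    assert!(can_reshape(32000, &[1, 32000]));
    println!("ok");
}
```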
Any help or suggestions would be greatly appreciated!
I would suggest running the code with `RUST_BACKTRACE=1` and with a cargo profile that includes debug symbols (e.g. the `release-with-debug` profile in the candle crate). This should produce a backtrace that will hopefully give more detail about why this reshape is called in the ONNX model.
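For reference, a profile like that could be defined in `Cargo.toml` as follows (a sketch; check candle's own workspace for the exact definition):

```toml
# Release optimizations plus debug symbols, so backtraces are symbolized.
[profile.release-with-debug]
inherits = "release"
debug = true
```

Then run with `RUST_BACKTRACE=1 cargo run --profile release-with-debug`.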
Thank you for your feedback @LaurentMazare,
While debugging the model execution issue, I added logging statements and noticed what seems to be a bug in the shape produced by a Concat operation. Below is some relevant information from the computation graph and logs.
Context
The error is raised just after a Concat operation, during a Reshape. I've highlighted the operation in the graph with a blue arrow in the image below.
Logs
// In Concat operation
Input0 in Concat: Tensor[512, 512; i64]
Output in Concat: Tensor[512, 512, 0; i64]
// In Reshape operation
node ["/base/spectrogram_extractor/Concat_output_0", "/base/spectrogram_extractor/Constant_2_output_0"]
/base/spectrogram_extractor/Concat_output_0 , Input shape [3] and Tensor[512, 512, 0; i64]
/base/spectrogram_extractor/Constant_2_output_0, Output shape [-1, 2]
It looks like the Concat is unexpectedly appending a 0 to the output dims, which changes the shape from [512, 512] to [512, 512, 0]. This leads to the shape mismatch in the downstream Reshape operation.
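For context on why a `[-1, 2]` target fails here: ONNX Reshape replaces a single -1 by `numel / product(other dims)`, and that division must be exact. A hedged sketch of that inference rule (illustrative, not candle's implementation):

```rust
// Sketch of ONNX Reshape's -1 handling: a single -1 in the target is
// replaced by numel / product(other dims); the division must be exact.
fn infer_reshape(numel: usize, target: &[i64]) -> Option<Vec<usize>> {
    let known: i64 = target.iter().filter(|&&d| d != -1).product();
    let mut dims = Vec::new();
    for &d in target {
        if d == -1 {
            // The inferred dim must divide evenly; 3 elements into [-1, 2] do not.
            if known == 0 || numel as i64 % known != 0 {
                return None;
            }
            dims.push((numel as i64 / known) as usize);
        } else {
            dims.push(d as usize);
        }
    }
    if dims.iter().product::<usize>() == numel { Some(dims) } else { None }
}

fn main() {
    assert_eq!(infer_reshape(6, &[-1, 2]), Some(vec![3usize, 2]));
    assert_eq!(infer_reshape(3, &[-1, 2]), None); // the failing case here
    println!("ok");
}
```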
Let me know if you need more context or if you'd like me to run additional diagnostics!
I also ran into this issue. I think the error might be in the implementation of the ConstantOfShape operation: it was using the shape of the input tensor rather than the input tensor's values to determine the shape of the output tensor.
It appears to have been fixed recently in the main branch (after 0.9.1) and the tests look correct now.
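To illustrate the suspected bug (a hypothetical sketch, not candle's actual code): per the ONNX spec, ConstantOfShape takes a 1-D i64 tensor whose *values* give the output shape, whereas the buggy behavior described above would use the input's own shape instead:

```rust
// Hypothetical sketch of the suspected ConstantOfShape bug; names and
// signatures are illustrative, not candle's actual implementation.

fn constant_of_shape_buggy(shape_input: &[i64]) -> Vec<usize> {
    // Bug: uses the shape of the input tensor (its length) as the output shape.
    vec![shape_input.len()]
}

fn constant_of_shape_fixed(shape_input: &[i64]) -> Vec<usize> {
    // Fix: the input tensor's values are the output shape.
    shape_input.iter().map(|&d| d as usize).collect()
}

fn main() {
    let shape = [512_i64, 512];
    assert_eq!(constant_of_shape_buggy(&shape), vec![2usize]); // wrong dims
    assert_eq!(constant_of_shape_fixed(&shape), vec![512usize, 512]); // expected dims
    println!("ok");
}
```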
Thanks for your feedback. I will retry it.