Repository: tracel-ai/burn
NaN output from `interpolate` for `[N, C, 1, W]` inputs in bilinear and bicubic modes on the burn-wgpu and burn-ndarray backends
Describe the bug
`interpolate` produces NaN values when the input tensor has shape `[N, C, 1, W]` (i.e. a height of 1) and the mode is bilinear or bicubic.
To Reproduce
- Enable the `test_1d_bilinear` test in `crates/burn-tensor/src/tests/module/bilinear_interpolate.rs` and the `test_1d_bicubic` test in `crates/burn-tensor/src/tests/module/bicubic_interpolate.rs`, which are currently ignored. (Copies of both tests are included below.)
- `cd burn-wgpu && cargo test`
- `cd burn-ndarray && cargo test`
- Enable the `resize_with_scales_1d_linear` test in `crates/burn-import/onnx-tests/tests/onnx_tests.rs`, which is also ignored.
NOTE: These tests pass with the burn-tch backend and with PyTorch.
Related Issues:
- #1246
- #2081
Copies of the ignored tests, in case the code is deleted:
/// Regression test for https://github.com/tracel-ai/burn/issues/2080: bicubic
/// interpolation of an input whose height dimension is 1 must not produce NaN.
///
/// A `[1, 1, 6]` tensor is unsqueezed at dim 2 into `[1, 1, 1, 6]` and resized to
/// `[1, 1, 1, 9]`. The output is first checked for NaN, then compared against
/// reference values using `assert_approx_eq` with precision 3.
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
fn test_1d_bicubic() {
    let device = Default::default();

    // Shape [1, 1, 6] -> [1, 1, 1, 6]: a single row, so the height axis has size 1.
    let input = TestTensor::<3>::from_floats(
        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
        &device,
    );
    let input = input.unsqueeze_dim(2);

    // Keep the height at 1 and upsample the width from 6 to 9.
    let output = interpolate(
        input,
        [1, 9],
        InterpolateOptions::new(InterpolateMode::Bicubic),
    );

    assert_eq!(output.dims(), [1, 1, 1, 9]);

    // Check for NaN explicitly first, so a failure points straight at the bug
    // rather than at a generic approximate-equality mismatch.
    assert!(
        !output
            .clone()
            .to_data()
            .as_slice::<f32>()
            .unwrap()
            .iter()
            .any(|&x| x.is_nan()),
        "interpolate output contains NaN"
    );

    // NOTE(review): reference values presumably come from PyTorch's bicubic
    // interpolation (the issue reports the test passes there) — confirm.
    TestTensor::<4>::from([[[[
        1.541f32, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794,
        -1.3986,
    ]]]])
    .to_data()
    .assert_approx_eq(&output.into_data(), 3);
}
/// Regression test for https://github.com/tracel-ai/burn/issues/2080: bilinear
/// interpolation of an input whose height dimension is 1 must not produce NaN.
///
/// A `[1, 1, 6]` tensor is unsqueezed at dim 2 into `[1, 1, 1, 6]` and resized to
/// `[1, 1, 1, 9]`. The output is first checked for NaN, then compared against
/// reference values using `assert_approx_eq` with precision 3.
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
fn test_1d_bilinear() {
    let device = Default::default();

    // Shape [1, 1, 6] -> [1, 1, 1, 6]: a single row, so the height axis has size 1.
    let input = TestTensor::<3>::from_floats(
        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
        &device,
    );
    let input = input.unsqueeze_dim(2);

    // Keep the height at 1 and upsample the width from 6 to 9.
    let output = interpolate(
        input,
        [1, 9],
        InterpolateOptions::new(InterpolateMode::Bilinear),
    );

    assert_eq!(output.dims(), [1, 1, 1, 9]);

    // Check for NaN explicitly first, so a failure points straight at the bug
    // rather than at a generic approximate-equality mismatch.
    assert!(
        !output
            .clone()
            .to_data()
            .as_slice::<f32>()
            .unwrap()
            .iter()
            .any(|&x| x.is_nan()),
        "interpolate output contains NaN"
    );

    // NOTE(review): reference values presumably come from PyTorch's bilinear
    // interpolation (the issue reports the test passes there) — confirm.
    TestTensor::<4>::from([[[[
        1.541f32,
        0.39450002,
        -0.76475,
        -1.943125,
        -0.80520004,
        0.36178753,
        -0.671275,
        -1.2022874,
        -1.3986,
    ]]]])
    .to_data()
    .assert_approx_eq(&output.into_data(), 3);
}
CC @laggui @louisfd