NaN related ops are not correct for burn-wgpu fusion backend
I am adding two new tensor methods related to NaN (`is_nan` and `contains_nan`), and my unit tests are failing for the burn-wgpu backend. See PR https://github.com/tracel-ai/burn/pull/2088.
**Describe the bug**

These two methods do not work as expected:
```rust
/// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
///
/// # Returns
///
/// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
pub fn is_nan(&self) -> Tensor<B, D, Bool> {
    // Check if the input tensor is NaN by comparing it to itself
    // NaN is the only value that is not equal to itself
    K::not_equal(self.primitive.clone(), self.primitive.clone())
}

/// Checks if the tensor contains any NaN values.
///
/// # Returns
///
/// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
    // Summing the tensor will result in NaN if the tensor contains any NaN values
    // This is faster than checking each element individually
    // because it rolls up the NaN values into a single value
    let sum = K::sum(self.primitive.clone());

    // Check if the sum is NaN by comparing it to itself
    K::not_equal(sum.clone(), sum)
}
```
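For reference, here is a minimal sketch of the behavior these methods are expected to have. It assumes the `Wgpu` backend and `Tensor::from_floats`; the actual unit tests live in the PR and may differ:

```rust
use burn::backend::Wgpu;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let with_nan = Tensor::<Wgpu, 1>::from_floats([1.0, f32::NAN, 3.0], &device);

    // Element-wise NaN mask; expected output: [false, true, false].
    println!("{}", with_nan.clone().is_nan());

    // Single-element reduction over the whole tensor; expected: true.
    assert!(with_nan.contains_nan().into_scalar());
}
```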
**To Reproduce**

- Check out the PR branch: https://github.com/tracel-ai/burn/pull/2088
- `cd burn-wgpu`
- `cargo test nan`
**Expected behavior**

The tests should pass.
NOTE: The unit tests are currently ignored for all backends and should be re-enabled once this is fixed.
@nathanielsimard @louisfd, this fails only for the fusion backend. Any particular difference?
For anyone coming here looking for a solution: `tensor.greater_equal_elem(f64::NEG_INFINITY).bool_not()` works for `tensor.is_nan()`, and a similar construction works for `tensor.any_nan()`.
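To make that workaround concrete, here is a minimal sketch. The helper names `is_nan_workaround` and `contains_nan_workaround` are illustrative, not part of burn's API. The trick: every non-NaN float satisfies `x >= -inf`, while any comparison involving NaN evaluates to false, so inverting the comparison marks exactly the NaN elements.

```rust
use burn::tensor::{backend::Backend, Bool, Tensor};

/// Element-wise NaN mask: NaN fails every comparison, including `>= -inf`,
/// so inverting `greater_equal_elem` leaves `true` exactly at NaN positions.
fn is_nan_workaround<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D, Bool> {
    tensor.greater_equal_elem(f64::NEG_INFINITY).bool_not()
}

/// Reduce the element-wise mask to a single boolean element with `any`.
fn contains_nan_workaround<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1, Bool> {
    is_nan_workaround(tensor).any()
}
```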
@antimora Fusion has changed a lot since the issue was opened. We should check if this is fixed.
OK. I'll check it.
Just checked on main; it still fails:
```
test tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan ... FAILED
test tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan ... FAILED

failures:

---- tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan stdout ----
thread 'tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan' panicked at crates/burn-wgpu/src/lib.rs:108:5:
assertion failed: with_nan.contains_nan().into_scalar()
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

---- tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan stdout ----
thread 'tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan' panicked at crates/burn-wgpu/src/lib.rs:108:5:
assertion `left == right` failed
  left: TensorData { bytes: [0, 1, 0, 1, 0, 0], shape: [2, 3], dtype: Bool }
 right: TensorData { bytes: [0, 0, 0, 0, 0, 0], shape: [2, 3], dtype: Bool }

failures:
    tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan
    tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan

test result: FAILED. 12 passed; 2 failed; 0 ignored; 0 measured; 2036 filtered out; finished in 1.34s

error: test failed, to rerun pass `--lib`
```
@ImTheSquid thanks for testing it again.
What's the reasoning behind ignoring the tests when they don't work? There are other tests that rely on `is_nan` and `contains_nan` working properly, and if they don't work then a lot of other functions suddenly come under suspicion as well.
Mainly to move faster. At the time, the wgpu backend was being completely reworked and we couldn't just fix the issue. We should bump the urgency on fixing it; we intentionally ignored the tests assuming it would be fixed quickly.
@nathanielsimard should we take this up soon?