NaN related ops are not correct for burn-wgpu fusion backend

Open antimora opened this issue 1 year ago • 8 comments

I am adding two new tensor methods related to NaN (is_nan and contains_nan). My unit tests fail on the burn-wgpu backend. See PR https://github.com/tracel-ai/burn/pull/2088

Describe the bug

These two methods do not work as expected:

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    pub fn is_nan(&self) -> Tensor<B, D, Bool> {
        // Check if the input tensor is NaN by comparing it to itself
        // NaN is the only value that is not equal to itself
        K::not_equal(self.primitive.clone(), self.primitive.clone())
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = K::sum(self.primitive.clone());

        // Check if the sum is NaN by comparing it to itself
        K::not_equal(sum.clone(), sum)
    }
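For reference, the behavior both methods rely on can be reproduced in plain Rust. This is only a minimal sketch on `f32` slices rather than Burn tensors, and the helper names `is_nan_mask` and `contains_nan_by_sum` are hypothetical (they are not part of Burn):

```rust
/// Element-wise NaN mask using self-inequality.
/// Under IEEE 754, NaN is the only value for which `x != x` is true.
/// (Hypothetical helper for illustration; not a Burn API.)
fn is_nan_mask(values: &[f32]) -> Vec<bool> {
    values.iter().map(|&x| x != x).collect()
}

/// Reduce-then-check: any NaN operand makes the sum NaN,
/// so a single self-inequality test on the sum is enough.
/// (Hypothetical helper for illustration; not a Burn API.)
fn contains_nan_by_sum(values: &[f32]) -> bool {
    let sum: f32 = values.iter().sum();
    sum != sum
}
```

On a CPU, `is_nan_mask(&[0.0, f32::NAN, 2.0])` flags only the middle element. One caveat of the sum-based check is that an input containing both positive and negative infinity also sums to NaN, so it would be reported as containing NaN even though no element is NaN.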

To Reproduce

  1. Check out the branch for this PR: https://github.com/tracel-ai/burn/pull/2088
  2. cd burn-wgpu
  3. cargo test nan

Expected behavior

The tests should pass.

NOTE: The unit tests are currently ignored for all backends and should be re-enabled once this is fixed.

antimora avatar Aug 01 '24 16:08 antimora

@nathanielsimard @louisfd, this fails only on the fusion backend. Is there a particular difference that would explain it?

antimora avatar Aug 03 '24 18:08 antimora

For anyone coming here looking for a solution: I found a workaround for tensor.is_nan() in tensor.greater_equal_elem(f64::NEG_INFINITY).bool_not(), and a similar approach works for tensor.any_nan().
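The workaround rests on the fact that every ordered comparison involving NaN is false, so `x >= -inf` holds for every value except NaN. A scalar sketch of the same idea in plain Rust (the function names are hypothetical and only mirror the tensor expression above; this is not Burn code):

```rust
/// Scalar analogue of `greater_equal_elem(f64::NEG_INFINITY).bool_not()`.
/// Every non-NaN value, including -inf itself, is >= -inf,
/// so negating the comparison leaves `true` only for NaN.
/// (Hypothetical helper for illustration; not a Burn API.)
fn is_nan_via_ordering(x: f64) -> bool {
    !(x >= f64::NEG_INFINITY)
}

/// Applies the same test to each element and reports whether any is NaN.
/// (Hypothetical helper for illustration; not a Burn API.)
fn any_nan_via_ordering(values: &[f64]) -> bool {
    values.iter().any(|&x| is_nan_via_ordering(x))
}
```

Unlike the `x != x` form, this compares against a constant rather than against the value itself.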

ImTheSquid avatar Nov 05 '24 20:11 ImTheSquid

@antimora Fusion has changed a lot since the issue was opened. We should check if this is fixed.

nathanielsimard avatar Nov 18 '24 16:11 nathanielsimard

OK. I'll check it.

antimora avatar Nov 18 '24 16:11 antimora

Just checked on main; it still fails:

test tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan ... FAILED
test tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan ... FAILED

failures:

---- tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan stdout ----
thread 'tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan' panicked at crates/burn-wgpu/src/lib.rs:108:5:
assertion failed: with_nan.contains_nan().into_scalar()
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

---- tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan stdout ----
thread 'tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan' panicked at crates/burn-wgpu/src/lib.rs:108:5:
assertion `left == right` failed
  left: TensorData { bytes: [0, 1, 0, 1, 0, 0], shape: [2, 3], dtype: Bool }
 right: TensorData { bytes: [0, 0, 0, 0, 0, 0], shape: [2, 3], dtype: Bool }


failures:
    tests::jit_fusion::tensor::f32_ty::nan::tests::contains_nan
    tests::jit_fusion::tensor::f32_ty::nan::tests::is_nan

test result: FAILED. 12 passed; 2 failed; 0 ignored; 0 measured; 2036 filtered out; finished in 1.34s

error: test failed, to rerun pass `--lib`

ImTheSquid avatar Nov 21 '24 17:11 ImTheSquid

@ImTheSquid thanks for testing it again.

antimora avatar Nov 21 '24 18:11 antimora

What's the reasoning behind ignoring the tests when they don't work? There are other tests that rely on is_nan and contains_nan working properly, and if they don't work then a lot of other functions suddenly come under suspicion as well.

ImTheSquid avatar Nov 21 '24 21:11 ImTheSquid

Mainly to move faster. At the time, the wgpu backend was being completely reworked, so we couldn't fix the issue right away.

We should bump the urgency of fixing it. We intentionally ignored the tests, expecting the issue to be fixed quickly.

@nathanielsimard should we take this up soon?

antimora avatar Nov 25 '24 17:11 antimora