
Cpu/Cuda conversion issue in Candle backend during batchnorm layer

Open michael-temp opened this issue 1 year ago • 1 comments

Hi there,

burn = { version = "0.11.1", default-features = false, features = ["train", "ndarray", "cuda", "candle"] }

This dependency configuration produced the following panic:

thread 'main' panicked at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.11.1/src/tensor/api/numeric.rs:41:9:
=== Tensor Operation Error ===
  Operation: 'Add'
  Reason:
    1. The provided tensors are not on the same device. Lhs tensor device Cpu, Rhs tensor device Cuda(0). 

stack backtrace:
   0: rust_begin_unwind
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/std/src/panicking.rs:597:5
   1: core::panicking::panic_fmt
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panicking.rs:72:14
   2: core::panicking::panic_display
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panicking.rs:178:5
   3: burn_tensor::tensor::api::float::<impl burn_tensor::tensor::api::base::Tensor<B,_>>::matmul::panic_cold_display
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panic.rs:99:13
   4: burn_tensor::tensor::api::numeric::<impl burn_tensor::tensor::api::base::Tensor<B,_,K>>::add
             at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.11.1/src/tensor/api/numeric.rs:41:9
   5: burn_core::nn::norm::batch::BatchNorm<B,_>::forward_train
             at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.11.1/src/nn/norm/batch.rs:133:28
   6: burn_core::nn::norm::batch::BatchNorm<B,_>::forward
             at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.11.1/src/nn/norm/batch.rs:85:21
   ...

It seems to be raised from this statement:

let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
    mean.clone()
        .detach()
        .mul_scalar(self.momentum)
        .reshape([channels]),
);

michael-temp, Dec 13 '23 09:12

This looks like the kind of bug that occurs when some methods always use the default device instead of the tensor's actual device. Default-device handling is under heavy refactoring in #1081, so I suggest we wait for that PR to merge.

louisfd, Dec 19 '23 19:12
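
For readers unfamiliar with this failure mode, here is a minimal, self-contained sketch of it. It is a toy model and not burn's actual API or implementation: `Device`, `Tensor`, `to_device`, and `update_running_mean` are all hypothetical names defined only for this illustration. It assumes the running statistic was created on a default device (CPU) while the batch mean lives on the input's device (CUDA), which is one plausible reading of the panic message above.

```rust
// Toy model of device-tagged tensors. NOT burn's API: every type and
// function here is hypothetical and exists only to illustrate the mismatch.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
    pub device: Device,
}

impl Tensor {
    /// Creates a zero-filled 1-D tensor on the given device.
    pub fn zeros(len: usize, device: Device) -> Self {
        Tensor { data: vec![0.0; len], device }
    }

    /// Multiplies every element by `s`; the result stays on the same device.
    pub fn mul_scalar(&self, s: f32) -> Self {
        Tensor {
            data: self.data.iter().map(|x| x * s).collect(),
            device: self.device,
        }
    }

    /// Returns a copy tagged with `device` (a real backend would move memory).
    pub fn to_device(&self, device: Device) -> Self {
        Tensor { data: self.data.clone(), device }
    }

    /// Element-wise add that rejects operands on different devices,
    /// analogous to the check reported in the panic message.
    pub fn add(&self, rhs: &Tensor) -> Result<Tensor, String> {
        if self.device != rhs.device {
            return Err(format!(
                "tensors are not on the same device: lhs {:?}, rhs {:?}",
                self.device, rhs.device
            ));
        }
        if self.data.len() != rhs.data.len() {
            return Err("shape mismatch".to_string());
        }
        Ok(Tensor {
            data: self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect(),
            device: self.device,
        })
    }
}

/// running = running * (1 - momentum) + batch_mean * momentum,
/// the same shape of update as the statement quoted in the issue.
pub fn update_running_mean(
    running: &Tensor,
    batch_mean: &Tensor,
    momentum: f32,
) -> Result<Tensor, String> {
    running
        .mul_scalar(1.0 - momentum)
        .add(&batch_mean.mul_scalar(momentum))
}
```

In this sketch, calling `update_running_mean` with a running mean built by `Tensor::zeros(2, Device::Cpu)` and a batch mean on `Device::Cuda(0)` returns an error, while first calling `running.to_device(batch_mean.device)` lets the update succeed. Whether the real fix in burn moves the running statistics or creates them on the right device in the first place is not something this thread states.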

https://github.com/tracel-ai/burn/pull/1081 has been merged, and I assume it fixes this. Closing for now. Let us know if the problem still occurs.

antimora, Mar 01 '24 17:03