burn
Cpu/Cuda conversion issue in Candle backend during batchnorm layer
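Hi there,
Using this dependency in Cargo.toml:
burn = { version = "0.11.1", default-features = false, features = ["train", "ndarray", "cuda", "candle"] }
With the Candle backend on a CUDA device, the BatchNorm forward pass panics with this error: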
thread 'main' panicked at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.11.1/src/tensor/api/numeric.rs:41:9:
=== Tensor Operation Error ===
Operation: 'Add'
Reason:
1. The provided tensors are not on the same device. Lhs tensor device Cpu, Rhs tensor device Cuda(0).
stack backtrace:
0: rust_begin_unwind
at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/std/src/panicking.rs:597:5
1: core::panicking::panic_fmt
at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panicking.rs:72:14
2: core::panicking::panic_display
at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panicking.rs:178:5
3: burn_tensor::tensor::api::float::<impl burn_tensor::tensor::api::base::Tensor<B,_>>::matmul::panic_cold_display
at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panic.rs:99:13
4: burn_tensor::tensor::api::numeric::<impl burn_tensor::tensor::api::base::Tensor<B,_,K>>::add
at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.11.1/src/tensor/api/numeric.rs:41:9
5: burn_core::nn::norm::batch::BatchNorm<B,_>::forward_train
at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.11.1/src/nn/norm/batch.rs:133:28
6: burn_core::nn::norm::batch::BatchNorm<B,_>::forward
at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.11.1/src/nn/norm/batch.rs:85:21
...
It seems to be raised from this statement in BatchNorm::forward_train (burn-core-0.11.1/src/nn/norm/batch.rs:133):
// Per the panic above, the lhs (running_mean) is on Cpu while the rhs (mean) is on Cuda(0).
let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
    mean.clone()
        .detach()
        .mul_scalar(self.momentum)
        .reshape([channels]),
);
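The statement above updates the running mean as a per-channel exponential moving average: running_mean ← (1 − momentum) · running_mean + momentum · mean. A minimal plain-Rust sketch of the same arithmetic on slices, ignoring devices entirely; the helper name `update_running_mean` and the momentum value are hypothetical, chosen only for illustration, and this is not Burn's implementation:

```rust
/// Hypothetical helper (not part of Burn): EMA update of BatchNorm
/// running statistics, one value per channel.
fn update_running_mean(running_mean: &[f64], batch_mean: &[f64], momentum: f64) -> Vec<f64> {
    assert_eq!(running_mean.len(), batch_mean.len(), "expected one value per channel");
    running_mean
        .iter()
        .zip(batch_mean)
        // Keep (1 - momentum) of the old estimate, blend in momentum of the new batch mean.
        .map(|(r, m)| r * (1.0 - momentum) + m * momentum)
        .collect()
}

fn main() {
    // Two channels, momentum 0.5 (assumed value so the arithmetic is exact).
    let running = vec![0.0, 2.0];
    let batch = vec![4.0, 6.0];
    let updated = update_running_mean(&running, &batch, 0.5);
    println!("{:?}", updated); // prints [2.0, 4.0]
}
```

In Burn both operands of this update are tensors, so they must also live on the same device, which is exactly the check that fails in the panic above.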
This looks like the kind of bug that occurs when some methods always use the default device instead of the device the tensors actually live on. Default device handling is being heavily refactored in #1081, so I suggest we wait for it to be merged.
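The suspected bug pattern can be illustrated with a toy model. This is a sketch under stated assumptions, not Burn's API and not the actual change made in #1081: `Device`, `Tensor`, `DEFAULT_DEVICE`, `init_state_on_default` and `init_state_like` are all hypothetical names, and the "default device" is assumed to resolve to the CPU. State allocated on a hard-coded default device cannot be combined with data on the GPU, while state allocated on the input's device can:

```rust
// Toy model (not Burn's API): a tensor that records which device it lives on.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

#[derive(Clone, Debug)]
struct Tensor {
    device: Device,
    data: Vec<f32>,
}

impl Tensor {
    fn zeros(len: usize, device: Device) -> Self {
        Tensor { device, data: vec![0.0; len] }
    }

    // Element-wise add that, like the panic above, refuses mixed devices.
    fn add(&self, rhs: &Tensor) -> Result<Tensor, String> {
        if self.device != rhs.device {
            return Err(format!(
                "Lhs tensor device {:?}, Rhs tensor device {:?}.",
                self.device, rhs.device
            ));
        }
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        Ok(Tensor { device: self.device, data })
    }
}

// Assumption for illustration: the "default device" resolves to the CPU.
const DEFAULT_DEVICE: Device = Device::Cpu;

// Buggy pattern: state is allocated on the default device,
// regardless of where the incoming data lives.
fn init_state_on_default(channels: usize) -> Tensor {
    Tensor::zeros(channels, DEFAULT_DEVICE)
}

// Fixed pattern: state is allocated on the same device as the input.
fn init_state_like(input: &Tensor) -> Tensor {
    Tensor::zeros(input.data.len(), input.device)
}

fn main() {
    let batch_mean = Tensor { device: Device::Cuda(0), data: vec![1.0, 2.0] };

    // Mixed devices: returns an Err describing Cpu vs Cuda(0).
    let stale = init_state_on_default(2);
    println!("{:?}", stale.add(&batch_mean));

    // Same device: the add succeeds.
    let fresh = init_state_like(&batch_mean);
    println!("{:?}", fresh.add(&batch_mean));
}
```

The design point is simply that any tensor created inside a module's forward pass should take its device from an existing tensor rather than from a global default.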
https://github.com/tracel-ai/burn/pull/1081 has been merged, and I assume it fixes this. Closing for now. Let us know if the problem still occurs.