burn
burn copied to clipboard
`RotaryEncoding` does not work for training on wgpu backend
Describe the bug
When trying to use RotaryEncoding with the wgpu backend in training, I crash with an error every time. It fails on relatively small input sizes (context length = 512, d_model = 256); I did not test on very small inputs. Depending on the build configuration the exact error changes, but here's an example on latest main without autotune:
wgpu error: Validation Error
Caused by:
In ComputePass::end
In a dispatch command, indirect:false
Each current dispatch group size dimension ([1, 1, 81920]) must be less or equal to 65535
To Reproduce
Try to use RotaryEncoding on the wgpu autodiff backend as part of a backward step.
Desktop (please complete the following information):
- Windows 10
- RTX 4070
Can you share a MWE?
Looks like an operation might not be respecting the device limits somehow.
@laggui Here's a MWE.
main.rs
use burn::{
backend::Wgpu, nn::{loss::CrossEntropyLossConfig, RotaryEncoding, RotaryEncodingConfig}, prelude::*
};
// Minimal reproduction: running RotaryEncoding::forward on the wgpu backend and
// feeding the result into a loss panics with a wgpu validation error — per the
// quoted errors, a compute dispatch ends up with a group-size dimension
// (e.g. [1, 1, 262144]) above wgpu's 65535 per-dimension limit.
fn main() {
type B = Wgpu;
let device = <B as Backend>::Device::default();
// Shapes chosen to trigger the crash. batch_size * context_length = 65536,
// which already exceeds the 65535 dispatch limit quoted in the error —
// presumably the offending dispatch derives from these dimensions (TODO confirm).
let context_length = 512;
let d_model = 256;
let n_heads = 4;
let batch_size = 128;
// Per-head dimension: 256 / 4 = 64.
let d_k = d_model / n_heads;
// RoPE sized for 2x the context length, one rotation table per head dimension.
// NOTE(review): theta = 0.001 is far from the common RoPE base of 10000 —
// presumably incidental to the crash; confirm it reproduces with the default too.
let rope: RotaryEncoding<B> = RotaryEncodingConfig::new(context_length * 2, d_model / n_heads)
.with_theta(0.001)
.init(&device);
// Zero-filled float input of shape [batch, heads, seq, d_k]; only the shape
// matters for reproducing the dispatch-size error, not the values.
let data_test_x = Tensor::<_, 4, Float>::from_data(
TensorData::new(vec![0.0; batch_size * n_heads * context_length * d_k], Shape::new([batch_size, n_heads, context_length, d_k])),
&device,
);
// Zero-filled integer targets, one class index per (batch, position) pair.
let data_test_y = Tensor::<_, 1, Int>::from_data(
TensorData::new(vec![0; batch_size * context_length], Shape::new([batch_size * context_length])),
&device,
);
// Apply RoPE, then flatten heads back into the model dimension:
// [batch, heads, seq, d_k] -> [batch, seq, heads, d_k] -> [batch*seq, heads*d_k].
// The backtrace shows the panic surfacing under this reshape (via fusion/matmul
// autotune), not inside rope.forward itself.
let x = rope.forward(data_test_x);
let x = x.swap_dims(1, 2).reshape([batch_size * context_length, n_heads * d_k]);
// Loss over the flattened logits — completes the graph that triggers the
// oversized kernel dispatch on wgpu.
CrossEntropyLossConfig::new()
.init(&x.device())
.forward(x, data_test_y);
}
Cargo.toml
[package]
name = "gpt-burn"
version = "0.1.0"
edition = "2024"
[dependencies]
burn = { git="https://github.com/Tracel-AI/burn", version = "0.17", features = [
"fusion",
"ndarray",
"train",
"vision",
"wgpu",
"metrics",
] }
Error
thread 'main' panicked at C:\Users\impor\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\wgpu-25.0.0\src\backend\wgpu_core.rs:2879:26:
wgpu error: Validation Error
Caused by:
In ComputePass::end
In a dispatch command, indirect:false
Each current dispatch group size dimension ([1, 1, 262144]) must be less or equal to 65535
stack backtrace:
0: std::panicking::begin_panic_handler
at /rustc/4d91de4e48198da2e33413efdcd9cd2cc0c46688/library\std\src\panicking.rs:692
1: core::panicking::panic_fmt
at /rustc/4d91de4e48198da2e33413efdcd9cd2cc0c46688/library\core\src\panicking.rs:75
2: wgpu::backend::wgpu_core::ContextWgpuCore::handle_error_inner
3: <wgpu::backend::wgpu_core::CoreComputePass as wgpu::dispatch::ComputePassInterface>::end
4: core::ptr::drop_in_place<core::option::Option<wgpu::api::compute_pass::ComputePass>>
5: cubecl_wgpu::compute::stream::WgpuStream::flush
6: cubecl_wgpu::compute::stream::WgpuStream::start_profile
7: cubecl_runtime::client::ComputeClient<Server,Channel>::profile
8: <alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter
9: cubecl_runtime::tune::tune_benchmark::TuneBenchmark<S,C,In,Out>::profile
10: cubecl_runtime::tune::local::LocalTuner<AK,ID>::execute
11: burn_cubecl::kernel::matmul::tune::base::matmul_autotune
12: ZN11burn_cubecl3ops9float_ops197_$LT$impl$u20$burn_tensor..tensor..ops..tensor..FloatTensorOps$LT$burn_cubecl..backend..CubeBackend$LT$R$C$F$C$I$C$BT$GT$$GT$$u20$for$u20$burn_cubecl..backend..CubeBackend$LT$R$C$F$C$I$C$BT$GT$$GT$12float_matmul17h4b09ad50c3
13: ZN366_$LT$burn_fusion..ops..float..$LT$impl$u20$burn_tensor..tensor..ops..tensor..FloatTensorOps$LT$burn_fusion..backend..Fusion$LT$B$GT$$GT$$u20$for$u20$burn_fusion..backend..Fusion$LT$B$GT$$GT$..float_matmul..MatmulOps$LT$B$GT$$u20$as$u20$burn_fusion..st
14: burn_fusion::stream::execution::base::<impl burn_fusion::stream::base::OperationQueue<R>>::execute
15: burn_fusion::stream::execution::processor::Processor<O>::process
16: burn_fusion::stream::multi::MultiStream<R>::register
17: <burn_fusion::client::mutex::MutexFusionClient<R> as burn_fusion::client::base::FusionClient<R>>::register
18: burn_fusion::ops::float::<impl burn_tensor::tensor::ops::tensor::FloatTensorOps<burn_fusion::backend::Fusion<B>> for burn_fusion::backend::Fusion<B>>::float_reshape
19: burn_tensor::tensor::api::base::Tensor<B,_,K>::reshape
20: hashbrown::raw::inner::RawTableInner::drop_inner_table
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
error: process didn't exit successfully: `target\release\gpt-burn.exe` (exit code: 101)