Batch matrix multiply leads to Vulkan error on WGPU
I wasn't sure whether batched matmul is supported, since it seems to be documented nowhere except in PyTorch's docs. It works fine with small tensors, but breaks down past a certain size:
use burn::backend::wgpu::{AutoGraphicsApi, WgpuRuntime};
use burn::tensor::{Distribution, Tensor};

type B = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;

// 500 * 500 = 250,000 batch matrices of shapes 4x5 and 5x6 respectively.
let a: Tensor<B, 4> = Tensor::random([500, 500, 4, 5], Distribution::Normal(-1.0, 1.0), &Default::default());
let b: Tensor<B, 4> = Tensor::random([500, 500, 5, 6], Distribution::Normal(-1.0, 1.0), &Default::default());
let out = a.matmul(b);
println!("{:?}", out);
I'd expect an output tensor of shape [500, 500, 4, 6], but instead I get the following error:
wgpu error: Validation Error
Caused by:
In a ComputePass
note: encoder = `<CommandBuffer-(1, 1, Vulkan)>`
In a dispatch command, indirect:false
note: compute pipeline = `<ComputePipeline-(4, 1, Vulkan)>`
Each current dispatch group size dimension ([1, 1, 250000]) must be less or equal to 65535
So it seems there's a maximum dispatch dimension of 65535 for batched matmul: the two leading batch dimensions get flattened into 500 * 500 = 250,000 workgroups, which exceeds the limit. I would expect this backend-specific limitation to be abstracted away, i.e. the backend should split the batched matmul into smaller dispatches and recombine the results automatically. Is there a current workaround for this?
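In the meantime, a manual version of that splitting seems to work as a stopgap: slice the leading batch dimension into chunks small enough that the flattened batch count stays under 65535, matmul each chunk, and concatenate the results. A minimal sketch, assuming burn 0.13's Tensor::slice and Tensor::cat APIs (chunked_matmul is just an illustrative helper name, not a burn API):

use burn::tensor::{backend::Backend, Tensor};

fn chunked_matmul<B: Backend>(a: Tensor<B, 4>, b: Tensor<B, 4>, chunk: usize) -> Tensor<B, 4> {
    let n = a.dims()[0];
    let mut parts = Vec::new();
    let mut start = 0;
    while start < n {
        let end = usize::min(start + chunk, n);
        // Slice only the leading batch dimension; the other dims stay whole.
        let part = a
            .clone()
            .slice([start..end])
            .matmul(b.clone().slice([start..end]));
        parts.push(part);
        start = end;
    }
    // Recombine the per-chunk results along the leading batch dimension.
    Tensor::cat(parts, 0)
}

// With chunk = 100, each dispatch covers 100 * 500 = 50,000 batch matrices,
// which stays under the 65535 limit.
let out = chunked_matmul(a, b, 100);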
I'm using burn 0.13.2 with Vulkan 1.3.283 on Fedora 40.