Batch matrix multiply leads to Vulkan error on WGPU

Open · jungerm2 opened this issue 8 months ago · 2 comments

I wasn't sure whether batched matmul is supported, since the only place I could find this behavior documented is PyTorch's documentation. It works fine with small tensors but breaks down past a certain size:

```rust
use burn::backend::wgpu::{AutoGraphicsApi, JitBackend, WgpuRuntime};
use burn::tensor::{Distribution, Tensor};

type B = JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;

fn main() {
    let a: Tensor<B, 4> = Tensor::random([500, 500, 4, 5], Distribution::Normal(-1.0, 1.0), &Default::default());
    let b: Tensor<B, 4> = Tensor::random([500, 500, 5, 6], Distribution::Normal(-1.0, 1.0), &Default::default());
    let out = a.matmul(b); // expected shape: [500, 500, 4, 6]
    println!("{:?}", out);
}
```

I'd expect an output tensor of shape [500, 500, 4, 6], but instead I get the following error:

```
wgpu error: Validation Error

Caused by:
    In a ComputePass
      note: encoder = `<CommandBuffer-(1, 1, Vulkan)>`
    In a dispatch command, indirect:false
      note: compute pipeline = `<ComputePipeline-(4, 1, Vulkan)>`
    Each current dispatch group size dimension ([1, 1, 250000]) must be less or equal to 65535
```
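For reference, the third dispatch dimension in the log matches the flattened batch size of the two leading dimensions. Assuming each dispatch can cover at most 65535 batch entries in that dimension, the smallest number of dispatches that would stay under the limit is:

```latex
500 \times 500 = 250\,000 > 65\,535, \qquad \left\lceil \frac{250\,000}{65\,535} \right\rceil = 4
```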

So it seems there's a maximum of 65535 per dispatch dimension, and here all 500 × 500 = 250000 batch entries land in a single dimension. I would expect this backend-specific limitation to be abstracted away, i.e. the backend should split the batched matmul into chunks that fit within the limit and recombine the results automatically. Is there a workaround for now?
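A minimal sketch of the split-and-recombine idea, in plain Rust rather than Burn (no Burn APIs are used). The names `bmm`, `chunked_bmm` and `MAX_DISPATCH_DIM` are hypothetical, and each CPU call to `bmm` merely stands in for one GPU dispatch. It assumes the batch dimensions have been flattened into one and that the per-chunk batch count is what lands in the limited dispatch dimension, as the `[1, 1, 250000]` in the log suggests.

```rust
// Hypothetical per-dimension limit, taken from the validation error above.
const MAX_DISPATCH_DIM: usize = 65_535;

/// Naive batched matmul over row-major buffers:
/// a is [batch, m, k], b is [batch, k, n], the result is [batch, m, n].
fn bmm(a: &[f32], b: &[f32], batch: usize, m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; batch * m * n];
    for t in 0..batch {
        let a_t = &a[t * m * k..(t + 1) * m * k];
        let b_t = &b[t * k * n..(t + 1) * k * n];
        let o_t = &mut out[t * m * n..(t + 1) * m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += a_t[i * k + p] * b_t[p * n + j];
                }
                o_t[i * n + j] = acc;
            }
        }
    }
    out
}

/// Splits the batch into chunks of at most `max_batch` matrices, runs `bmm`
/// on each chunk (one "dispatch" per chunk), and concatenates the results.
fn chunked_bmm(a: &[f32], b: &[f32], batch: usize, m: usize, k: usize, n: usize, max_batch: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(batch * m * n);
    let mut start = 0;
    while start < batch {
        let end = (start + max_batch).min(batch);
        let chunk = bmm(&a[start * m * k..end * m * k], &b[start * k * n..end * k * n], end - start, m, k, n);
        out.extend_from_slice(&chunk);
        start = end;
    }
    out
}

fn main() {
    // Same per-matrix shapes as the issue (4x5 @ 5x6), with a flattened batch of 500 * 500.
    let (batch, m, k, n) = (500 * 500, 4, 5, 6);
    // Deterministic small integers instead of random normals, so the comparison is exact.
    let a: Vec<f32> = (0..batch * m * k).map(|i| (i % 7) as f32 - 3.0).collect();
    let b: Vec<f32> = (0..batch * k * n).map(|i| (i % 5) as f32 - 2.0).collect();

    let full = bmm(&a, &b, batch, m, k, n);
    let chunked = chunked_bmm(&a, &b, batch, m, k, n, MAX_DISPATCH_DIM);
    let chunks = (batch + MAX_DISPATCH_DIM - 1) / MAX_DISPATCH_DIM;
    println!("chunks = {chunks}, equal = {}", full == chunked);
}
```

Running it should print `chunks = 4, equal = true`. With Burn tensors, the same workaround could presumably be written by flattening the two batch dimensions with `reshape`, slicing chunks along that dimension, calling `matmul` on each chunk, and joining the pieces with `Tensor::cat`; this has not been checked against 0.13.2.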

I'm using burn 0.13.2 with Vulkan 1.3.283 on Fedora 40.

jungerm2 · Jun 07 '24 14:06