
Vulkan support via Blade

Open kvark opened this issue 2 months ago • 3 comments

Blade-graphics is a lightweight GPU abstraction layer with Vulkan, Metal, and GLES backends. For this PR, we are only interested in the Vulkan side of things.

Shaders are composed as WGSL text. We could technically compose them in Naga IR directly, and that would be more fun, but the current Luminal code is better suited for text.
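As a rough illustration of what composing a kernel as WGSL text can look like, here is a minimal Rust sketch; the helper name, signature, and binding names are hypothetical and not Luminal's actual codegen:

// Minimal sketch, assuming a hypothetical helper; not Luminal's actual codegen.
fn compose_kernel(workgroup: (u32, u32, u32), body: &str) -> String {
    let mut src = String::new();
    // Global bindings are emitted as storage buffers rather than raw pointers.
    src.push_str("var<storage, read> a: array<f32>;\n");
    src.push_str("var<storage, read_write> out: array<f32>;\n\n");
    src.push_str(&format!(
        "@workgroup_size({},{},{})\n@compute fn main(\n",
        workgroup.0, workgroup.1, workgroup.2
    ));
    src.push_str("\t@builtin(workgroup_id) blockIdx: vec3<u32>,\n");
    src.push_str("\t@builtin(local_invocation_id) threadIdx: vec3<u32>,\n) {\n");
    src.push_str(body);
    src.push_str("\n}\n");
    src
}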

State of things

  • can run part of "matmul" demo (with "--feature blade")
  • no Vulkan validation errors
  • stops at "TensorCore loop" part (to be fixed)
  • all buffers are in shared memory (~~to be moved to private~~)
  • no dynamic variables (~~to be supported~~)
  • codegen has a bit of divergence (since WGSL is not C-like), and it could be reduced (see the sketch after this list)
  • the node-to-variable mapping is a bit of a hack
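To make the divergence concrete: C-like backends can fold an offset into a pointer and then index off that pointer, while WGSL has no pointer arithmetic into storage buffers, so the generated code carries plain u32 offsets and always subscripts the global binding. A minimal sketch of the idea (the enum and function are hypothetical, not the actual emitter):

// Hypothetical illustration of emitting a buffer access per backend style.
enum Backend {
    CLike,
    Wgsl,
}

fn emit_access(backend: &Backend, binding: &str, offset: &str, idx: &str) -> String {
    match backend {
        // C-like: the offset can live in a pointer, e.g. `float* a_ptr = a + off;`
        Backend::CLike => format!("{binding}_ptr[{idx}]"),
        // WGSL: keep the offset as a u32 and index the global binding directly.
        Backend::Wgsl => format!("{binding}[{offset} + u32({idx})]"),
    }
}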

Results

Luminal-blade0

Output on Linux (doesn't look correct, need to look into it more):

FASTEST (27ms): Kernels: 1
Kernel 1 Grid: (512, 512, 1) Threadblock: (1, 1, 1) Smem: 0
var<storage, read> a: array<f32>;  // GMEM(acc_0)
var<storage, read> b: array<f32>;  // GMEM(A Load)
var<storage, read> c: array<f32>;  // GMEM(B Load)
var<storage, read_write> d: array<f32>; // Output

@workgroup_size(1,1,1)
@compute fn main(
        @builtin(workgroup_id) blockIdx: vec3<u32>,
        @builtin(local_invocation_id) threadIdx: vec3<u32>,
) {
        let loop_f = blockIdx.x;
        let g = 0 + (512*u32(loop_f));
        let h = 0 + (512*u32(loop_f));
        let loop_i = blockIdx.y;
        let j = 0 + u32(loop_i);
        let k = h + u32(loop_i);
        var l = array<f32, 1>();
        for (var load = 0; load < 1; load+=1) {
                l[0] = a[0];
        }
        for (var loop_m = 0; loop_m < 512; loop_m+=1) {
                let n = g + u32(loop_m);
                let o = j + (512*u32(loop_m));
                let p = c[o] * b[n];
                let q = p + l[0];
                l[0] = q;
        }
        d[k] = l[0];
}
Outputs: [262144]
Valids: 1264 / 2165
[512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0]

Output on Mac:

FASTEST (46ms): Kernels: 2
Kernel 1 Grid: (512, 8, 1) Threadblock: (8, 8, 1) Smem: 0
var<storage, read> a: array<f32>;  // GMEM(A Load)
var<storage, read> b: array<f32>;  // GMEM(B Load)
var<storage, read_write> c: array<f32>; // Output

@workgroup_size(8,8,1)
@compute fn main(
        @builtin(workgroup_id) blockIdx: vec3<u32>,
        @builtin(local_invocation_id) threadIdx: vec3<u32>,
) {
        let loop_e = blockIdx.x;
        let f = 0 + (512*u32(loop_e));
        let g = 0 + (u32(loop_e)*262144);
        let loop_h = blockIdx.y;
        let i = 0 + (64*u32(loop_h));
        let j = g + (u32(loop_h)*32768);
        let loop_k = threadIdx.x;
        let l = i + (u32(loop_k)*8);
        let m = j + (u32(loop_k)*4096);
        let loop_n = threadIdx.y;
        let o = l + u32(loop_n);
        let p = m + (512*u32(loop_n));
        for (var loop_q = 0; loop_q < 512; loop_q+=1) {
                let r = f + (((u32(loop_q)/8)*8)+(u32(loop_q)%8));
                let s = o + (((u32(loop_q)/8)*4096)+((u32(loop_q)%8)*512));
                let t = p + (((u32(loop_q)/8)*8)+(u32(loop_q)%8));
                let u = b[s] * a[r];
                c[t] = u;
        }
}
Outputs: [max(134217728, (8+(8*(511/8))))]
Kernel 2 Grid: (64, 8, 1) Threadblock: (512, 1, 1) Smem: 0
var<storage, read> a: array<f32>; 
var<storage, read> b: array<f32>;  // GMEM(acc_0)
var<storage, read_write> c: array<f32>; // Output

@workgroup_size(512,1,1)
@compute fn main(
        @builtin(workgroup_id) blockIdx: vec3<u32>,
        @builtin(local_invocation_id) threadIdx: vec3<u32>,
) {
        let loop_e = blockIdx.x;
        let f = 0 + (u32(loop_e)*2097152);
        let g = 0 + (u32(loop_e)*4096);
        let loop_h = blockIdx.y;
        let i = f + (u32(loop_h)*262144);
        let j = g + (512*u32(loop_h));
        let loop_k = threadIdx.x;
        let l = i + (512*u32(loop_k));
        let m = j + u32(loop_k);
        var n = array<f32, 1>();
        for (var load = 0; load < 1; load+=1) {
                n[0] = b[0];
        }
        for (var loop_o = 0; loop_o < 512; loop_o+=1) {
                let p = l + u32(loop_o);
                let q = n[0] + a[p];
                n[0] = q;
        }
        c[m] = n[0];
}
Outputs: [262144]
Valids: 1264 / 2165
[512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0]

kvark, Sep 12 '25

Amazing work! Seems like you've listed out the remaining things to be done. How does the Var struct work? I've been meaning to have a more unified way to handle variables.

jafioti, Sep 12 '25

Var is replacing the old tuple of (index, is_pointer) with { index, pointer_owner }. The pointer_owner is Some(owner_index) if this value is a pointer pointing to some global binding. So, pointer_owner.is_some() is equivalent to the old is_pointer boolean. In addition, for backends that don't support pointer arithmetic (like Blade's WGSL), those "pointer-like" variables are just indices into the global bindings.
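A rough sketch of the shape described above (field names and types are assumptions based on this description, not the exact code in the PR):

// Sketch of the Var described above; exact names/types are assumptions.
struct Var {
    /// Index of this variable in the kernel's variable table.
    index: usize,
    /// Some(owner_index) when this variable is pointer-like, i.e. it refers
    /// into the global binding `owner_index` rather than holding a value.
    pointer_owner: Option<usize>,
}

impl Var {
    /// Equivalent to the old `is_pointer` boolean from the (index, is_pointer) tuple.
    fn is_pointer(&self) -> bool {
        self.pointer_owner.is_some()
    }
}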

kvark, Sep 13 '25

Gotcha, that's nice. What would the pointer owner mean? Like, if the pointer is coming from an input, is it just the input's index?

jafioti, Sep 13 '25