Vulkan support via Blade
Blade-graphics is a lightweight GPU abstraction layer with Vulkan, Metal, and GLES backends. For this PR, we are only interested in the Vulkan side of things.
Shaders are composed as WGSL text. We could technically compose them in Naga IR directly, which would be more fun, but the current Luminal code is better suited for text.
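To illustrate the text-based approach, here is a minimal sketch of assembling a WGSL kernel by string formatting, mirroring the entry-point shape seen in the listings below. This is a hypothetical illustration (the function name emit_kernel and its signature are assumptions), not the actual Luminal codegen:

```rust
// Hypothetical sketch: building WGSL source as text, mirroring the
// @workgroup_size / @compute entry-point shape in the kernel listings.
// Not the actual Luminal implementation.
fn emit_kernel(workgroup: (u32, u32, u32), body: &str) -> String {
    format!(
        "@workgroup_size({},{},{})\n\
         @compute fn main(\n\
         \x20 @builtin(workgroup_id) blockIdx: vec3<u32>,\n\
         \x20 @builtin(local_invocation_id) threadIdx: vec3<u32>,\n\
         ) {{\n{}\n}}\n",
        workgroup.0, workgroup.1, workgroup.2, body
    )
}

fn main() {
    let src = emit_kernel((8, 8, 1), "  let i = blockIdx.x;");
    // The generated source carries the workgroup size and entry point.
    assert!(src.contains("@workgroup_size(8,8,1)"));
    assert!(src.contains("@compute fn main("));
    println!("{src}");
}
```

Composing Naga IR directly would skip the parse step at pipeline-creation time, but string emission keeps the codegen close to the existing C-like backends.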
State of things
- can run part of "matmul" demo (with "--feature blade")
- no Vulkan validation errors
- stops at "TensorCore loop" part (to be fixed)
- all buffers are in shared memory (~~to be moved to private~~)
- no dynamic variables (~~to be supported~~)
- codegen diverges a bit from the other backends (since WGSL is not C-like), and that divergence could be reduced
- the node-to-variable mapping is a bit of a hack
Results
Output on Linux (doesn't look correct; needs more investigation):
FASTEST (27ms): Kernels: 1
Kernel 1 Grid: (512, 512, 1) Threadblock: (1, 1, 1) Smem: 0
var<storage, read> a: array<f32>; // GMEM(acc_0)
var<storage, read> b: array<f32>; // GMEM(A Load)
var<storage, read> c: array<f32>; // GMEM(B Load)
var<storage, read_write> d: array<f32>; // Output
@workgroup_size(1,1,1)
@compute fn main(
@builtin(workgroup_id) blockIdx: vec3<u32>,
@builtin(local_invocation_id) threadIdx: vec3<u32>,
) {
let loop_f = blockIdx.x;
let g = 0 + (512*u32(loop_f));
let h = 0 + (512*u32(loop_f));
let loop_i = blockIdx.y;
let j = 0 + u32(loop_i);
let k = h + u32(loop_i);
var l = array<f32, 1>();
for (var load = 0; load < 1; load+=1) {
l[0] = a[0];
}
for (var loop_m = 0; loop_m < 512; loop_m+=1) {
let n = g + u32(loop_m);
let o = j + (512*u32(loop_m));
let p = c[o] * b[n];
let q = p + l[0];
l[0] = q;
}
d[k] = l[0];
}
Outputs: [262144]
Valids: 1264 / 2165
[512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0]
Output on Mac:
FASTEST (46ms): Kernels: 2
Kernel 1 Grid: (512, 8, 1) Threadblock: (8, 8, 1) Smem: 0
var<storage, read> a: array<f32>; // GMEM(A Load)
var<storage, read> b: array<f32>; // GMEM(B Load)
var<storage, read_write> c: array<f32>; // Output
@workgroup_size(8,8,1)
@compute fn main(
@builtin(workgroup_id) blockIdx: vec3<u32>,
@builtin(local_invocation_id) threadIdx: vec3<u32>,
) {
let loop_e = blockIdx.x;
let f = 0 + (512*u32(loop_e));
let g = 0 + (u32(loop_e)*262144);
let loop_h = blockIdx.y;
let i = 0 + (64*u32(loop_h));
let j = g + (u32(loop_h)*32768);
let loop_k = threadIdx.x;
let l = i + (u32(loop_k)*8);
let m = j + (u32(loop_k)*4096);
let loop_n = threadIdx.y;
let o = l + u32(loop_n);
let p = m + (512*u32(loop_n));
for (var loop_q = 0; loop_q < 512; loop_q+=1) {
let r = f + (((u32(loop_q)/8)*8)+(u32(loop_q)%8));
let s = o + (((u32(loop_q)/8)*4096)+((u32(loop_q)%8)*512));
let t = p + (((u32(loop_q)/8)*8)+(u32(loop_q)%8));
let u = b[s] * a[r];
c[t] = u;
}
}
Outputs: [max(134217728, (8+(8*(511/8))))]
Kernel 2 Grid: (64, 8, 1) Threadblock: (512, 1, 1) Smem: 0
var<storage, read> a: array<f32>;
var<storage, read> b: array<f32>; // GMEM(acc_0)
var<storage, read_write> c: array<f32>; // Output
@workgroup_size(512,1,1)
@compute fn main(
@builtin(workgroup_id) blockIdx: vec3<u32>,
@builtin(local_invocation_id) threadIdx: vec3<u32>,
) {
let loop_e = blockIdx.x;
let f = 0 + (u32(loop_e)*2097152);
let g = 0 + (u32(loop_e)*4096);
let loop_h = blockIdx.y;
let i = f + (u32(loop_h)*262144);
let j = g + (512*u32(loop_h));
let loop_k = threadIdx.x;
let l = i + (512*u32(loop_k));
let m = j + u32(loop_k);
var n = array<f32, 1>();
for (var load = 0; load < 1; load+=1) {
n[0] = b[0];
}
for (var loop_o = 0; loop_o < 512; loop_o+=1) {
let p = l + u32(loop_o);
let q = n[0] + a[p];
n[0] = q;
}
c[m] = n[0];
}
Outputs: [262144]
Valids: 1264 / 2165
[512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0, 512.0]
Amazing work! Seems like you've listed out the remaining things to be done. How does the Var struct work? I've been meaning to have a more unified way to handle variables.
Var replaces the old tuple of (index, is_pointer) with { index, pointer_owner }. The pointer_owner is Some(owner_index) if this value is a pointer into some global binding, so pointer_owner.is_some() is equivalent to the old is_pointer boolean. In addition, for backends that don't support pointer arithmetic (like Blade's WGSL), those "pointer-like" variables are just indices into the global bindings.
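The description above can be sketched roughly as follows. This is a minimal illustration of the { index, pointer_owner } shape; the exact field types and method names are assumptions, not the actual Luminal code:

```rust
// Hypothetical sketch of the Var struct described above; field and
// method names are assumptions, not the actual Luminal definition.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Var {
    /// Index of this variable in the kernel's variable list.
    index: usize,
    /// Some(owner_index) if this value is a pointer into the global
    /// binding at owner_index; None for plain values.
    pointer_owner: Option<usize>,
}

impl Var {
    /// Equivalent of the old is_pointer boolean.
    fn is_pointer(&self) -> bool {
        self.pointer_owner.is_some()
    }
}

fn main() {
    let value = Var { index: 3, pointer_owner: None };
    let ptr = Var { index: 4, pointer_owner: Some(1) };
    assert!(!value.is_pointer());
    assert!(ptr.is_pointer());
}
```

On backends without pointer arithmetic (like Blade's WGSL), the pointer_owner index is what lets codegen turn a "pointer-like" variable back into an index into the owning global binding.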
gotcha, that's nice. what would the pointer owner mean? like if the pointer is coming from an input, is it just the input's index?