Elementwise fusion causes a.max(b) to fail on tensors larger than 3

Open jafioti opened this issue 11 months ago • 1 comments

I believe this is why test_max fails on Metal. The generated kernel doesn't look correct:

#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device float* input1 [[buffer(1)]], device float* input0 [[buffer(0)]], device float* input2 [[buffer(2)]], device float *out [[buffer(3)]], device uint& n_elements [[buffer(4)]], uint idx [[thread_position_in_grid]]) {
    if (idx < n_elements) {
        out[idx] = ((input2[idx] < input0[idx]) * ((input2[idx] < input0[idx]))) + ((input1[0] - (input2[idx] < input0[idx])) * input2[idx]);
    }
}
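
To make the problem concrete, here is a small CPU sketch in Python that mirrors the expression in the kernel above one element at a time. It rests on two assumptions that the kernel itself does not confirm: that input1 holds the constant 1.0, and that the intended fused expression is the usual compare-and-blend form of max, lt * input0 + (1 - lt) * input2. The function names and sample values are made up for illustration.

```python
def emitted(input0, input1, input2):
    """Evaluate the expression exactly as it appears in the generated kernel."""
    out = []
    for i in range(len(input0)):
        # In Metal the bool result of `<` is promoted to 0.0 or 1.0.
        lt = float(input2[i] < input0[i])
        # First term is lt * lt: the mask is multiplied by itself.
        out.append(lt * lt + (input1[0] - lt) * input2[i])
    return out


def intended(input0, input1, input2):
    """Assumed intent: blend the two inputs with the mask to get max."""
    out = []
    for i in range(len(input0)):
        lt = float(input2[i] < input0[i])
        # Here the mask selects input0 where input0 is the larger value.
        out.append(lt * input0[i] + (input1[0] - lt) * input2[i])
    return out


# Hypothetical sample data; `one` stands in for input1 (assumed to be 1.0).
x = [1.0, 5.0, 2.0, 7.0]  # plays the role of input0
y = [4.0, 3.0, 6.0, 0.5]  # plays the role of input2
one = [1.0]

print(emitted(x, one, y))   # wrong wherever input0 is the larger value
print(intended(x, one, y))  # matches the elementwise max
```

Under those assumptions, the emitted expression is only right when input2 is the larger element. Wherever input0 is larger, the first term collapses to the mask value instead of input0, which suggests the fusion substituted the comparison result for one of its inputs.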

jafioti, Mar 02 '24 21:03

Also, I'm going to write a dedicated Metal max op, which will solve this as well, but it's important to fix the elementwise fusion first.

jafioti, Mar 02 '24 21:03

This is now fixed by the improved elementwise fusion (#39).

jafioti, Apr 09 '24 23:04