luminal
Elementwise fusion causes a.max(b) to fail on tensors larger than 3
I believe this is why `test_max` fails on Metal. The generated kernel doesn't look correct:
```metal
#include <metal_stdlib>
using namespace metal;

kernel void mkernel(device float* input1 [[buffer(1)]],
                    device float* input0 [[buffer(0)]],
                    device float* input2 [[buffer(2)]],
                    device float *out [[buffer(3)]],
                    device uint& n_elements [[buffer(4)]],
                    uint idx [[thread_position_in_grid]]) {
    if (idx < n_elements) {
        out[idx] = ((input2[idx] < input0[idx]) * ((input2[idx] < input0[idx]))) + ((input1[0] - (input2[idx] < input0[idx])) * input2[idx]);
    }
}
```
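To make the problem concrete, here is a small CPU sketch in C++ of the per-element expression above. It assumes that `max(a, b)` is lowered to `lt * a + (1 - lt) * b` with `lt = (b < a)`, that `input0` holds `a`, that `input2` holds `b`, and that `input1[0]` is the constant `1.0`. These are assumptions read off the kernel, not confirmed details of luminal's lowering, and the function names are hypothetical. Under these assumptions the generated first term multiplies the mask by itself instead of by `a`, so whenever `a > b` the result is `1.0` rather than `a`.

```cpp
#include <algorithm>  // std::max, a reference result to compare against
#include <cassert>    // assert, used to document the constant-1 assumption

// Hypothetical re-implementation of the expression emitted in the kernel above.
// a ~ input0[idx], b ~ input2[idx], one ~ input1[0] (assumed to be 1.0f).
float fused_max_as_generated(float a, float b, float one) {
    float lt = static_cast<float>(b < a);  // 1.0f when a is the larger value
    // Suspected bug: the mask is multiplied by itself instead of by `a`.
    return (lt * lt) + ((one - lt) * b);
}

// What the fused kernel would compute under the assumed lowering
// max(a, b) = lt * a + (1 - lt) * b, where lt = (b < a).
float fused_max_expected(float a, float b, float one) {
    assert(one == 1.0f && "input1[0] is assumed to be the constant 1");
    float lt = static_cast<float>(b < a);
    return (lt * a) + ((one - lt) * b);
}

// Plain elementwise max, for comparison with the two fused expressions.
float reference_max(float a, float b) {
    return std::max(a, b);
}
```

For example, with `a = 5` and `b = 2`, `fused_max_as_generated` returns `1` while `fused_max_expected` and `reference_max` both return `5`; when `a <= b` the mask is `0` and both fused expressions correctly fall through to `b`.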
I'm also going to write a dedicated max Metal op, which will sidestep this as well, but the elementwise fusion bug needs to be fixed first.
This has now been fixed by the improved elementwise fusion in #39.