luminal
Aggressive elementwise fusion
Currently, elementwise fusion is very conservative about what it fuses. It could be made much more aggressive by:
- Fusing constants into kernels
- Fusing across shape changes and contiguous ops (by stacking views inside the kernel)
- Handling intermediate elementwise outputs that get used multiple times by downstream computation:
b = exp(a); // Intermediate
c = b + sin(b);
It should be possible to fuse this test down to a single kernel:
#[test]
fn test_fusion() {
    let mut cx = Graph::new();
    let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
    let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
    let c = cx.constant(2.123);
    let d = cx.named_tensor::<R1<10>>("d").set(random_vec(10)).keep();
    let mut out = ((a.exp2() - b.sin()).relu() + c.expand() / d).retrieve();
    cx.execute();
    let unopt_out = out.data();
    out.drop();
    cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut out);
    cx.execute();
    assert_close(&out.data(), &unopt_out);
}
As of 3e956f91e77c0e134c51553f1528a03e1acffa02, we can now fuse across arbitrary reshapes / contiguous ops. The other half is still to come: supporting common subexpressions internal to the kernel.
As of 780951c8289693095eb56b4653d4b9353cf0b083, the fusion is good enough to remove the custom rope kernel! There's a slight performance disadvantage (17.8 tps custom vs 17.0 tps automatic), mainly due to not yet fusing in the subexpressions, but this approach demonstrates powerful kernel generation.
What happens if we do tensor.slice(..1).cos().pad([0, 1])? The sliced-out indexes get zero-filled and then passed through cos, and since cos(0) = 1, the sliced-out pieces come out as 1s, which are then padded back in.
Instead, we need to insert indexing and valid expressions between each and every component of the equation, not just stack them at the beginning.
As of 589148707f5c53f479703b709d663257a21f3760, this is no longer a problem: I have rewritten the entire fusion to use subexpressions and do proper valid checks. Each subexpression now performs its own valid check, rather than a single valid check at the beginning. As a side benefit, the kernels are much easier to read!
The only remaining issue is that mistral still outputs gibberish and runs slowly for some reason when fusion is on.
Ok, mistral is finally fixed. Current fusion is slower than before (~15 tps vs 17 tps) but, unlike before, it is correct. The slowdown is almost certainly due to huge index and valid expressions, which will be solved in #47.