
Aggressive elementwise fusion

Open jafioti opened this issue 11 months ago • 1 comments

Currently the elementwise fusion is very conservative in what it fuses. It can be a lot more aggressive by:

  • Fusing constants into kernels
  • Fusing across shape changes and contiguous ops (by stacking views inside the kernel)
  • Handling intermediate elementwise outputs that get used multiple times by downstream computation:
b = exp(a); // Intermediate
c = b + sin(b);
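As a hypothetical sketch (not luminal's actual codegen), fusing `b = exp(a); c = b + sin(b)` into one elementwise pass means the shared intermediate `b` is computed once per element inside the kernel, rather than materialized in a separate buffer:

```rust
// Illustrative only: one fused elementwise pass over `a`, with the
// common subexpression `b` evaluated once per element.
fn fused(a: &[f32]) -> Vec<f32> {
    a.iter()
        .map(|&x| {
            let b = x.exp(); // intermediate, reused below
            b + b.sin()
        })
        .collect()
}
```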

It should be possible to fuse this test down to a single kernel:

#[test]
fn test_fusion() {
    let mut cx = Graph::new();
    let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
    let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
    let c = cx.constant(2.123);
    let d = cx.named_tensor::<R1<10>>("d").set(random_vec(10)).keep();
    let mut out = ((a.exp2() - b.sin()).relu() + c.expand() / d).retrieve();

    cx.execute();
    let unopt_out = out.data();
    out.drop();

    cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut out);
    cx.execute();

    assert_close(&out.data(), &unopt_out);
}

jafioti avatar Mar 06 '24 19:03 jafioti

As of 3e956f91e77c0e134c51553f1528a03e1acffa02 we can now fuse across arbitrary reshapes / contiguous ops. The other half of this is still to come: supporting common subexpressions internal to the kernel.
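To illustrate what "stacking views inside the kernel" means (an assumed sketch, not luminal's actual code): rather than materializing a reshaped copy, the fused kernel composes the view's index expression with the load, then applies the elementwise op. Here a (2, 3) input is read through a transpose view into a (3, 2) output:

```rust
// Assumed illustration: the transpose is never materialized; its
// index expression is stacked onto the load inside the fused kernel.
fn fused_transpose_exp(a: &[f32; 6]) -> [f32; 6] {
    let mut out = [0.0f32; 6];
    for i in 0..6 {
        let (rt, ct) = (i / 2, i % 2); // coords in the (3, 2) output
        let src = ct * 3 + rt;         // stacked view: index into the (2, 3) input
        out[i] = a[src].exp();         // elementwise op fused through the view
    }
    out
}
```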

jafioti avatar Mar 11 '24 20:03 jafioti

As of 780951c8289693095eb56b4653d4b9353cf0b083 the fusion is good enough to remove the custom rope kernel! There's a slight performance disadvantage (17.8 tps custom vs 17.0 tps automatic), mainly due to not yet fusing in the subexpressions, but this approach demonstrates powerful kernel generation.

jafioti avatar Mar 12 '24 19:03 jafioti

What happens if we do tensor.slice(..1).cos().pad([0, 1])? The sliced-out indices get read as 0 and passed through cos, and since cos(0) = 1, the positions that should come back as zero padding come out as 1s instead.

Instead we need to insert indexing and valid expressions between each and every component of the equation, not just stack them at the beginning.
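The difference can be sketched per-element for tensor.slice(..1).cos().pad([0, 1]) on a 2-element input (a hedged illustration, not the generated Metal kernel):

```rust
// Bug: all valid checks stacked at the start, so the out-of-slice
// load becomes 0.0 and cos runs on it, leaking cos(0) = 1 into the
// padded slot.
fn fused_wrong(a: &[f32; 2]) -> [f32; 2] {
    let mut out = [0.0f32; 2];
    for i in 0..2 {
        let x = if i < 1 { a[i] } else { 0.0 }; // slice(..1) valid check, up front
        out[i] = x.cos();                       // pad region wrongly becomes 1.0
    }
    out
}

// Fix: the valid expression is interleaved with the subexpression,
// so cos is only applied where the slice was valid and the pad stays 0.
fn fused_right(a: &[f32; 2]) -> [f32; 2] {
    let mut out = [0.0f32; 2];
    for i in 0..2 {
        out[i] = if i < 1 { a[i].cos() } else { 0.0 };
    }
    out
}
```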

jafioti avatar Apr 04 '24 19:04 jafioti

As of 589148707f5c53f479703b709d663257a21f3760 this is no longer a problem: I have rewritten the entire fusion to use subexpressions and do valid checks properly. Each subexpression now does its own valid check, rather than a single valid check at the beginning. As a side benefit, the kernels are much easier to read!

The only remaining issue is that Mistral still outputs gibberish and runs slowly for some reason when fusion is on.

jafioti avatar Apr 07 '24 00:04 jafioti

OK, Mistral is finally fixed. Current fusion is slower than before (~15 tps vs 17 tps), but unlike before, it is correct. The slowdown is almost certainly due to huge index and valid expressions; this will be solved in #47.

jafioti avatar Apr 09 '24 23:04 jafioti