
Llama example bug on M3

Open jorgeantonio21 opened this issue 1 year ago • 7 comments

I have been experimenting with the Luminal examples on my Macbook M3, and every time I run

cargo run --release --features metal

I get a nonsensical answer, like:

```python<|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|>...

jorgeantonio21 avatar Apr 28 '24 09:04 jorgeantonio21

This is a bug I've been aware of; I think it's due to some hardware differences between the M1/M2 and the M3. The example works on the M1 and M2, but not on the M3.

Would you be able to `cd` into `crates/luminal_metal`, run `cargo test -- --test-threads 1`, and send the output? I'm curious whether all the tests pass for you.

jafioti avatar Apr 28 '24 17:04 jafioti

Thanks for your reply! Sure, the output is:

test tests::fp32::test_sub_4096 ... ok
test tests::fp32::test_sub_50 ... ok
test tests::fp32::test_sub_783 ... ok
test tests::fp32::test_sum_reduce ... ok
test tests::fp32::test_transformer_encoder_block ... ok
test unary::tests::test_norms ... ok

failures:

---- tests::fp16::test_pad_contig stdout ----
thread 'tests::fp16::test_pad_contig' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:80:5:
assertion `left == right` failed: Number of elements doesn't match
  left: 312
 right: 208
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

---- tests::fp16::test_softmax_3 stdout ----
thread 'tests::fp16::test_softmax_3' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.3190918, index 0, avg distance: inf

---- tests::fp16::test_softmax_4096 stdout ----
thread 'tests::fp16::test_softmax_4096' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.00040984154, index 0, avg distance: inf

---- tests::fp16::test_softmax_50 stdout ----
thread 'tests::fp16::test_softmax_50' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.026504517, index 0, avg distance: inf

---- tests::fp16::test_softmax_783 stdout ----
thread 'tests::fp16::test_softmax_783' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.0017166138, index 0, avg distance: inf

---- tests::fp32::test_softmax_3 stdout ----
thread 'tests::fp32::test_softmax_3' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.31912953, index 0, avg distance: inf

---- tests::fp32::test_softmax_4096 stdout ----
thread 'tests::fp32::test_softmax_4096' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.00032597585, index 0, avg distance: inf

---- tests::fp32::test_softmax_50 stdout ----
thread 'tests::fp32::test_softmax_50' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.026488636, index 0, avg distance: inf

---- tests::fp32::test_softmax_783 stdout ----
thread 'tests::fp32::test_softmax_783' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.0017388357, index 0, avg distance: inf


failures:
    tests::fp16::test_pad_contig
    tests::fp16::test_softmax_3
    tests::fp16::test_softmax_4096
    tests::fp16::test_softmax_50
    tests::fp16::test_softmax_783
    tests::fp32::test_softmax_3
    tests::fp32::test_softmax_4096
    tests::fp32::test_softmax_50
    tests::fp32::test_softmax_783

test result: FAILED. 170 passed; 9 failed; 0 ignored; 0 measured; 0 filtered out; finished in 17.13s

error: test failed, to rerun pass `--lib`

jorgeantonio21 avatar Apr 29 '24 10:04 jorgeantonio21

Also, output token generation is sometimes very fast (around 80 tokens/sec) and sometimes extremely slow (around 0.1 tokens/sec). I guess this issue will be resolved once better Metal kernels are integrated for the M3.

jorgeantonio21 avatar Apr 29 '24 10:04 jorgeantonio21

Hmm, yeah, the `inf` is a problem; that's why you're getting the strange outputs. The logits are probably NaN or inf. I'll look into the differences between the M1/M2 and the M3.

jafioti avatar Apr 29 '24 14:04 jafioti

So the `pad_contig` failure is fine; I fixed that, and it should work for you now. There's something wrong in the handwritten softmax kernel on the M3 that works on the other chips. I'll have to take a look and work it out, but at least it's narrowed down to one kernel.
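For context on why a softmax kernel can emit `inf`/NaN like the test failures above: a naive softmax exponentiates the raw logits, and `exp` overflows `f32` for large inputs, giving `inf / inf = NaN`. The standard fix subtracts the row maximum before exponentiating. This is a generic sketch of that failure mode, not the actual Luminal Metal kernel:

```rust
// Naive softmax: exp(x) overflows f32 for large logits, so the row sum
// becomes inf and each probability becomes inf / inf = NaN.
fn softmax_naive(xs: &[f32]) -> Vec<f32> {
    let sum: f32 = xs.iter().map(|x| x.exp()).sum();
    xs.iter().map(|x| x.exp() / sum).collect()
}

// Stable softmax: subtracting the row max keeps every exponent <= 0,
// so nothing overflows and the result is mathematically identical.
fn softmax_stable(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = xs.iter().map(|x| (x - max).exp()).sum();
    xs.iter().map(|x| (x - max).exp() / sum).collect()
}

fn main() {
    let logits = [1000.0f32, 1001.0, 1002.0];
    // exp(1000) overflows f32, so the naive version produces NaN.
    assert!(softmax_naive(&logits).iter().any(|p| p.is_nan()));
    // The max-subtracted version stays finite and sums to ~1.
    let stable = softmax_stable(&logits);
    assert!(stable.iter().all(|p| p.is_finite()));
    assert!((stable.iter().sum::<f32>() - 1.0).abs() < 1e-5);
}
```

On a GPU the same idea applies per row, but the max and sum become threadgroup/simdgroup reductions, which is where chip-generation differences can bite.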

jafioti avatar Apr 29 '24 15:04 jafioti

Great! Thank you so much @jafioti for the time you've put into this; please let me know of any updates.

jorgeantonio21 avatar Apr 29 '24 18:04 jorgeantonio21

Would you be able to comment out the `MetalSoftmaxCompiler` line in `luminal_metal/src/lib.rs`? It should be in the `SpecialOpsCompiler`.

If that works, then we know the issue is in the softmax.
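The suggested change might look something like this. This is only a sketch: the actual contents and shape of `SpecialOpsCompiler` in `luminal_metal/src/lib.rs` may differ, and the other tuple member shown here is a hypothetical placeholder:

```rust
// Sketch (not the real file contents): compiler passes composed as a
// tuple type; commenting out the softmax entry disables the handwritten
// kernel, so softmax falls back to being built from primitive ops.
pub type SpecialOpsCompiler<T> = (
    // MetalSoftmaxCompiler<T>,   // disabled to test whether this kernel is the culprit
    OtherMetalCompilers<T>,       // hypothetical placeholder for the remaining passes
);
```

If the Llama example then produces sensible output, the handwritten softmax kernel is confirmed as the source of the `inf` values.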

jafioti avatar May 06 '24 14:05 jafioti

@jorgeantonio21 Did you get a chance to try this with the softmax compiler turned off?

jafioti avatar May 28 '24 15:05 jafioti

> Would you be able to comment out the `MetalSoftmaxCompiler` line in `luminal_metal/src/lib.rs`? It should be in the `SpecialOpsCompiler`.
>
> If that works, then we know the issue is in the softmax.

@jafioti Bingo! Commenting out the softmax compiler makes the Llama example work perfectly on the M3 Max. Phi also works, albeit with a bogus response.

mikeseven avatar Jun 04 '24 22:06 mikeseven

@mikeseven Awesome, thanks for letting me know. I removed the softmax op entirely, since it was only responsible for a couple of milliseconds of gain, and it will be subsumed by flash attention soon anyway.

jafioti avatar Jun 05 '24 15:06 jafioti

I responded in the Phi-on-M3 issue; just duplicating here. What I meant by a bogus response for Phi is a model accuracy issue; Llama's answer is more precise. Can't wait to test flash attention!

mikeseven avatar Jun 06 '24 17:06 mikeseven

Thank you for this @jafioti !

jorgeantonio21 avatar Jun 06 '24 21:06 jorgeantonio21