luminal
Llama example bug on M3
I have been experimenting with the Luminal examples on my MacBook M3, and every time I run
cargo run --release --features metal
I get a nonsensical answer, like:
```python<|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|><|reserved_special_token_250|
...
This is a bug I've been aware of. I think it's due to hardware differences between the M1/M2 and the M3: the example works on M1 and M2, but not on M3.
Would you be able to cd into crates/luminal_metal, run cargo test -- --test-threads 1, and send the output? I'm curious whether all the tests pass for you.
Thanks for your reply! Sure, the output is:
test tests::fp32::test_sub_4096 ... ok
test tests::fp32::test_sub_50 ... ok
test tests::fp32::test_sub_783 ... ok
test tests::fp32::test_sum_reduce ... ok
test tests::fp32::test_transformer_encoder_block ... ok
test unary::tests::test_norms ... ok
failures:
---- tests::fp16::test_pad_contig stdout ----
thread 'tests::fp16::test_pad_contig' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:80:5:
assertion `left == right` failed: Number of elements doesn't match
left: 312
right: 208
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
---- tests::fp16::test_softmax_3 stdout ----
thread 'tests::fp16::test_softmax_3' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.3190918, index 0, avg distance: inf
---- tests::fp16::test_softmax_4096 stdout ----
thread 'tests::fp16::test_softmax_4096' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.00040984154, index 0, avg distance: inf
---- tests::fp16::test_softmax_50 stdout ----
thread 'tests::fp16::test_softmax_50' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.026504517, index 0, avg distance: inf
---- tests::fp16::test_softmax_783 stdout ----
thread 'tests::fp16::test_softmax_783' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.0017166138, index 0, avg distance: inf
---- tests::fp32::test_softmax_3 stdout ----
thread 'tests::fp32::test_softmax_3' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.31912953, index 0, avg distance: inf
---- tests::fp32::test_softmax_4096 stdout ----
thread 'tests::fp32::test_softmax_4096' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.00032597585, index 0, avg distance: inf
---- tests::fp32::test_softmax_50 stdout ----
thread 'tests::fp32::test_softmax_50' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.026488636, index 0, avg distance: inf
---- tests::fp32::test_softmax_783 stdout ----
thread 'tests::fp32::test_softmax_783' panicked at /Users/jorgeantonio/dev/luminal/src/tests/mod.rs:83:13:
inf is not close to 0.0017388357, index 0, avg distance: inf
failures:
tests::fp16::test_pad_contig
tests::fp16::test_softmax_3
tests::fp16::test_softmax_4096
tests::fp16::test_softmax_50
tests::fp16::test_softmax_783
tests::fp32::test_softmax_3
tests::fp32::test_softmax_4096
tests::fp32::test_softmax_50
tests::fp32::test_softmax_783
test result: FAILED. 170 passed; 9 failed; 0 ignored; 0 measured; 0 filtered out; finished in 17.13s
error: test failed, to rerun pass `--lib`
Also, token generation is sometimes very fast (around 80 tokens/sec) and sometimes extremely slow (around 0.1 tokens/sec). I guess this will be resolved once better Metal kernels are integrated for the M3.
Hmm, yeah, the inf is a problem, and that's why you're getting the strange outputs. The logits are probably NaN or inf. I'll look into the differences between the M1/M2 and the M3.
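To illustrate why a non-finite value in softmax is enough to produce garbage tokens, here is a minimal CPU-side sketch in plain Rust. It is not the Metal kernel, and it does not claim that overflow is what goes wrong on the M3 (the thread does not identify the cause). The function names `softmax_naive`, `softmax_stable`, and `first_non_finite` are made up for this example.

```rust
/// Naive softmax: exponentiates the raw logits, so large inputs overflow to inf.
fn softmax_naive(logits: &[f32]) -> Vec<f32> {
    let exps: Vec<f32> = logits.iter().map(|&x| x.exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Stable softmax: subtracts the maximum first, so every exponent is <= 0.
fn softmax_stable(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns the index of the first NaN or +/-inf entry, if there is one.
fn first_non_finite(values: &[f32]) -> Option<usize> {
    values.iter().position(|v| !v.is_finite())
}

fn main() {
    // Illustrative logits only; exp(1000.0) overflows f32 to inf.
    let logits = [1000.0_f32, 999.0, 998.0];

    let naive = softmax_naive(&logits);
    let stable = softmax_stable(&logits);

    // inf / inf is NaN, so every naive probability is NaN.
    println!("naive:  {:?} (first non-finite at {:?})", naive, first_non_finite(&naive));
    // The stable version stays finite and sums to roughly 1.
    println!("stable: {:?} (first non-finite at {:?})", stable, first_non_finite(&stable));
}
```

Once the probabilities are NaN, any argmax or sampling step downstream has nothing meaningful to choose from, which is consistent with the repeated reserved token in the output above.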
So the pad_contig failure is fine; I fixed that and it should work for you now. There's something wrong in the handwritten softmax kernel on M3, even though it works on the other chips. I'll have to take a look and work it out, but at least it's narrowed down to one kernel.
Great! Thank you so much @jafioti for the time you put into this. Please let me know of any updates.
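The failing assertions in the test output appear to be element-wise closeness checks against a reference result. A rough sketch of that kind of check follows, so it is clear why a single inf or NaN fails it immediately. `check_close` is a hypothetical helper written for this example, not the function used in luminal's test suite, and the values in `main` are illustrative.

```rust
/// Compares `actual` against `expected` element by element.
/// Hypothetical helper for illustration; not luminal's test code.
fn check_close(actual: &[f32], expected: &[f32], tol: f32) -> Result<(), String> {
    if actual.len() != expected.len() {
        return Err(format!(
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        ));
    }
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        // Written as a negated `<=` so that NaN (which fails every comparison)
        // and inf (whose distance is inf) are both rejected.
        if !((a - e).abs() <= tol) {
            return Err(format!("{a} is not close to {e}, index {i}"));
        }
    }
    Ok(())
}

fn main() {
    // Illustrative reference values, not taken from the real tests.
    let expected = [0.25_f32, 0.25, 0.5];

    let good = [0.25_f32, 0.25, 0.5];
    let bad = [f32::INFINITY, 0.25, 0.5];

    println!("{:?}", check_close(&good, &expected, 1e-3));
    println!("{:?}", check_close(&bad, &expected, 1e-3));
}
```

Running a check like this per kernel against a CPU reference is one way to narrow a backend problem down to a single op, as happened here with softmax.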
Would you be able to comment out the MetalSoftmaxCompiler line in luminal_metal/src/lib.rs? It should be in the SpecialOpsCompiler.
If that works, then we know the issue is in the softmax kernel.
@jorgeantonio21 Did you get a chance to try this with the softmax compiler turned off?
@jafioti Bingo! Commenting out the softmax compiler makes the llama example work perfectly on M3 Max. Phi also works, albeit with a bogus response.
@mikeseven Awesome, thanks for letting me know. I removed the softmax op entirely, since it was only responsible for a gain of a couple of milliseconds and it will be subsumed by flash attention soon anyway.
I responded in the phi-on-M3 issue; just duplicating here. What I meant by a bogus response for phi is a model accuracy issue: the llama answer is more precise. Can't wait to test flash attention!
Thank you for this, @jafioti!