Flash Attention test fails for bfloat16
Hi, I have enabled bfloat16 testing in Triton (https://github.com/openai/triton/pull/1244/), but I'm getting the following error with this data type:
giorgio@giorgio:triton$ pytest python/test/unit/operators/test_flash_attention.py -s
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.10.9, pytest-7.2.1, pluggy-1.0.0
rootdir: /usr/local/home/giorgio/triton/python
collected 2 items
python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failed
LLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Fatal Python error: Aborted
Thread 0x00007fbc757fe6c0 (most recent call first):
<no Python frame>
Current thread 0x00007fbd053ae200 (most recent call first):
File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1018 in ttgir_to_llir
File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1570 in <lambda>
File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1637 in compile
File "<string>", line 41 in _fwd_kernel
File "/usr/local/home/giorgio/triton/python/triton/ops/flash_attention.py", line 214 in forward
File "/usr/local/home/giorgio/triton/python/test/unit/operators/test_flash_attention.py", line 33 in test_op
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 1789 in runtest
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 260 in <lambda>
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 339 in from_call
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 220 in call_and_report
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
File "/usr/local/home/giorgio/.local/bin/pytest", line 8 in <module>
Extension modules: torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg.lapack_lite, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, __triton_launcher, cuda_utils (total: 23)
Aborted
Could this get fixed please? Thanks
LLVM didn't support bfloat16, so we used a customized type to represent it. Now that we've upgraded to LLVM head, there may be room for improvement.
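For background, bfloat16 is just the top 16 bits of an IEEE float32, which is why a backend without a native bf16 type can fall back to i16 storage plus integer bit manipulation. A minimal numpy sketch of that round trip (purely illustrative; not Triton's actual lowering):

import numpy as np

def f32_to_bf16_bits(x):
    # Keep the top 16 bits of the float32 bit pattern (truncation; real
    # hardware rounds to nearest even, this sketch rounds toward zero).
    return np.uint16(np.float32(x).view(np.uint32) >> np.uint32(16))

def bf16_bits_to_f32(b):
    # Re-expand: the bf16 bits become the high half of a float32 pattern.
    return (np.uint32(b) << np.uint32(16)).view(np.float32)

print(bf16_bits_to_f32(f32_to_bf16_bits(3.14159)))  # 3.140625: only 8 significand bits survive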
Hi @Jokeren thank you for your prompt response! Are you referring to https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c? I've tried to run this on the latest commit (https://github.com/openai/triton/commit/a38d2defb8ad6f759c5f422d1a1c9848c6310355 at the time of writing), and it still gives the same result.
No, I meant after https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c, we can modify our previous bfloat16 code a bit. Will take a look
Oh right, sorry about the confusion. Thanks!
@giorgio-arena I tested it just a few hours ago with the latest code. It worked with bf16.
@ptillet @daadaada
The problem might be caused by this PR: https://github.com/openai/triton/pull/1107
We've specially handled the shared_layout->dot_layout conversion by treating 2xi16 and 4xi8 as a single i32, but haven't handled mma_layout->dot_layout yet. It's a bit trickier because mma_layout stores two separate i16s by default. If we concatenate two i16s together, we will likely apply an or and a shl on every element, causing additional overhead.
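To sketch the packing in question (illustrative Python showing the general idea, not the actual TritonGPU lowering): fusing two bf16 elements held as separate 16-bit values into one 32-bit word costs a shl plus an or per pair, which is the overhead mentioned above.

import numpy as np

def pack_bf16_pair(lo, hi):
    # shl + or per pair: the extra work when the source layout keeps
    # the two 16-bit halves in separate registers.
    return np.uint32(lo) | (np.uint32(hi) << np.uint32(16))

def unpack_bf16_pair(word):
    # Split a packed u32 back into its two bf16 halves.
    return np.uint16(word & np.uint32(0xFFFF)), np.uint16(word >> np.uint32(16))

w = pack_bf16_pair(np.uint16(0x4049), np.uint16(0x3F80))  # bf16(3.140625), bf16(1.0)
assert unpack_bf16_pair(w) == (np.uint16(0x4049), np.uint16(0x3F80))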
Proof-of-concept changes: https://github.com/openai/triton/commit/087ad498d8a7cde113ae0f6f59d1f1fbbc41e9ee
Still facing accuracy problems.
@Jokeren: just how bad is the accuracy?
It doesn't pass the test... I haven't looked into the details; maybe I was wrong on some of the bit manipulation.
We're putting out a few fires right now, but we'll look into this more closely once things cool down.
The branch has been updated. It might just be a minor precision problem now? You are welcome to verify; I have no idea how precise bf16 should be.
tensor(-0.0002, device='cuda:0', dtype=torch.bfloat16) tensor(-5.1223e-09, device='cuda:0', dtype=torch.bfloat16)
It's not related to https://github.com/openai/triton/pull/1267, because the problem persists even with ptxas -O0.
Hi, I'm still getting the same error
python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing 'LLVMTranslationDialectInterface' registration for dialect for op: builtin.unrealized_conversion_cast
Even when doing
git fetch && git checkout origin/keren/fix-bf16
Am I doing something wrong? How did you test this?
I cannot reproduce the issue locally. Check that you've done the following steps:
rm -rf build
pip install -e .
rm -rf ~/.triton/cache
pip uninstall pytorch-triton -y
If you're still observing the same error, please copy and paste the generated ttgir from under ~/.triton/cache.
Hi @Jokeren, thank you for that; you were right, I wasn't rebuilding properly! :) Also, the numerical divergence that I'm getting is not too bad:
Arrays are not almost equal to 2 decimals
Mismatched elements: 203 / 12582912 (0.00161%)
Max absolute difference: 0.03125
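For reference, this failure comes from numpy's decimal-based check: decimal=2 requires abs(actual - desired) < 1.5 * 10**-2, so a max difference of 0.03125 on even a handful of elements trips it. A sketch of what loosening the tolerance looks like, with synthetic data rather than the actual test:

import numpy as np

ref = np.random.randn(1024).astype(np.float32)
# Stand-in for the kernel output: perturb by up to 0.03, mimicking the mismatch above.
out = ref + np.random.uniform(-0.03, 0.03, ref.shape).astype(np.float32)
np.testing.assert_almost_equal(out, ref, decimal=1)  # passes: threshold is 1.5e-1
# decimal=2 would fail here: threshold 1.5e-2 is below the 0.03 perturbation.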
Could I ask what the status is on merging this branch to main? Is there any plan to do so anytime soon? Thanks
So @tridao suggested that I test against the original flash attention first, and I will have to confirm with @ptillet on this before the merge.
Honestly, I think we can just merge. The max divergence seems very reasonable considering that bfloat16 has fewer mantissa bits than float16.
I see, but we have to update the error limits first, right? I'll do that sometime later this week.
yeah, or we can reduce the input size. Seems like these errors are pretty unlikely
To make it more consistent with the existing tests, I only reduced the decimal limit for v, and it passed.
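To quantify the "fewer mantissa bits" point from above, a quick check with standard torch.finfo (nothing Triton-specific):

import torch

for dt in (torch.float16, torch.bfloat16):
    print(dt, torch.finfo(dt).eps)
# torch.float16  -> 0.0009765625 (2**-10: 10 explicit mantissa bits)
# torch.bfloat16 -> 0.0078125    (2**-7:   7 explicit mantissa bits)
# So a max divergence of 0.03125 (2**-5) on unit-scale values after a long
# matmul/softmax accumulation is roughly what bf16's resolution predicts.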
I'm just wondering whether LLVM still doesn't support the bfloat16 type? I get the same error in another test case.
The test case for the bfloat16 type is python/test/unit/language/test_core.py.