
Flash Attention test fails for bfloat16

Open giorgio-arena opened this issue 2 years ago • 22 comments

Hi, I have enabled bfloat16 testing in Triton (https://github.com/openai/triton/pull/1244/), but I'm getting the following error with that data type:

giorgio@giorgio:triton$ pytest python/test/unit/operators/test_flash_attention.py -s
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.10.9, pytest-7.2.1, pluggy-1.0.0
rootdir: /usr/local/home/giorgio/triton/python
collected 2 items                                                                                                                                                                                       

python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Fatal Python error: Aborted

Thread 0x00007fbc757fe6c0 (most recent call first):
  <no Python frame>

Current thread 0x00007fbd053ae200 (most recent call first):
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1018 in ttgir_to_llir
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1570 in <lambda>
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1637 in compile
  File "<string>", line 41 in _fwd_kernel
  File "/usr/local/home/giorgio/triton/python/triton/ops/flash_attention.py", line 214 in forward
  File "/usr/local/home/giorgio/triton/python/test/unit/operators/test_flash_attention.py", line 33 in test_op
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 1789 in runtest
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 339 in from_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/usr/local/home/giorgio/.local/bin/pytest", line 8 in <module>

Extension modules: torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg.lapack_lite, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, __triton_launcher, cuda_utils (total: 23)
Aborted
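
For reference, a minimal standalone repro (outside pytest) looks roughly like the sketch below. The triton.ops.attention call signature (q, k, v, sm_scale) and the test shapes are assumptions taken from flash_attention.py and the unit test at this commit, so treat it as illustrative only.

import torch
import triton
import triton.ops  # flash attention lives in triton.ops.flash_attention

# Shapes mirror the flash attention unit test; bfloat16 is the dtype that aborts.
Z, H, N_CTX, D_HEAD = 4, 48, 1024, 64
dtype = torch.bfloat16  # torch.float16 compiles fine

def make_input():
    return (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
            .normal_(mean=0.0, std=0.5)
            .requires_grad_())

q, k, v = make_input(), make_input(), make_input()
sm_scale = 0.3

# This triggers the TritonGPU -> LLVM IR translation failure shown above.
out = triton.ops.attention(q, k, v, sm_scale)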

Could this get fixed please? Thanks

giorgio-arena avatar Feb 23 '23 15:02 giorgio-arena

LLVM didn't support bfloat16, so we used a customized type to represent it. Now that we've upgraded to LLVM head, there could be some improvements.

Jokeren avatar Feb 23 '23 16:02 Jokeren

Hi @Jokeren thank you for your prompt response! Are you referring to https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c? I've tried to run this on the latest commit (https://github.com/openai/triton/commit/a38d2defb8ad6f759c5f422d1a1c9848c6310355 at the time of writing), and it still gives the same result.

giorgio-arena avatar Feb 23 '23 17:02 giorgio-arena

No, I meant after https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c, we can modify our previous bfloat16 code a bit. Will take a look

Jokeren avatar Feb 23 '23 17:02 Jokeren

Oh right, sorry about the confusion. Thanks!

giorgio-arena avatar Feb 23 '23 17:02 giorgio-arena

@giorgio-arena I tested it just a few hours ago with the latest code. It worked with bf16.

linxihui avatar Feb 23 '23 18:02 linxihui

@ptillet @daadaada The problem might be caused by this PR: https://github.com/openai/triton/pull/1107. We special-cased the shared_layout->dot_layout conversion by treating 2xi16 and 4xi8 as a single i32, but we haven't handled mma_layout->dot_layout yet.

It's a bit trickier because mma_layout stores two separate i16s by default. If we concatenate two i16s together, we will likely have to apply shl and or on every element, causing additional overhead.
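
As a rough illustration of the packing being discussed (a sketch only; the names are made up and this is host-side Python, not the actual lowering code): two bf16 lanes can be concatenated into one i32 by shifting the upper lane left by 16 and OR-ing in the lower lane.

import struct

def bf16_bits(x: float) -> int:
    # bfloat16 is the upper 16 bits of an IEEE-754 float32.
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", x))
    return f32_bits >> 16

def pack_bf16_pair(lo: float, hi: float) -> int:
    # shl the upper lane by 16, then or in the lower lane -> one 32-bit word.
    return (bf16_bits(hi) << 16) | bf16_bits(lo)

assert pack_bf16_pair(1.0, 2.0) == 0x40003F80  # 2.0 -> 0x4000, 1.0 -> 0x3F80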

Jokeren avatar Feb 24 '23 08:02 Jokeren

Proof-of-concept changes: https://github.com/openai/triton/commit/087ad498d8a7cde113ae0f6f59d1f1fbbc41e9ee

I'm still facing accuracy problems, though.

Jokeren avatar Feb 24 '23 08:02 Jokeren

@Jokeren: just how bad is the accuracy?

sbodenstein avatar Feb 27 '23 10:02 sbodenstein

It doesn't pass the test... I haven't looked into the details; maybe I got some of the bit manipulation wrong.

Jokeren avatar Feb 27 '23 17:02 Jokeren

We're putting out a few fires right now, but we'll look into this more closely once things cool down.

ptillet avatar Feb 27 '23 19:02 ptillet

The branch has been updated. It might be just a minor precision problem now? You're welcome to verify; I have no idea how precise bf16 should be.

tensor(-0.0002, device='cuda:0', dtype=torch.bfloat16) tensor(-5.1223e-09, device='cuda:0', dtype=torch.bfloat16)

Jokeren avatar Mar 02 '23 02:03 Jokeren

It's not related to https://github.com/openai/triton/pull/1267, because the problem persists even with ptxas -O0.

Jokeren avatar Mar 02 '23 02:03 Jokeren

Hi, I'm still getting the same error:

python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing 'LLVMTranslationDialectInterface' registration for dialect for op: builtin.unrealized_conversion_cast

even when doing git fetch && git checkout origin/keren/fix-bf16. Am I doing something wrong? How did you test this?

giorgio-arena avatar Mar 02 '23 12:03 giorgio-arena

I cannot reproduce the issue locally. Please check that you've done the following steps:

rm -rf build
pip install -e .
rm -rf ~/.triton/cache
pip uninstall pytorch-triton -y

If you're still observing the same error, please copy and paste the generated ttgir under ~/.triton/cache

Jokeren avatar Mar 02 '23 17:03 Jokeren

Hi @Jokeren, thank you for that; you were right, I wasn't rebuilding properly! :) Also, the numerical divergence I'm getting is not too bad:

Arrays are not almost equal to 2 decimals

Mismatched elements: 203 / 12582912 (0.00161%)
Max absolute difference: 0.03125

Could I ask what the status of merging this branch into main is? Is there any plan to do so anytime soon? Thanks

giorgio-arena avatar Mar 08 '23 13:03 giorgio-arena

So @tridao suggested testing against the original flash attention first, and I will have to confirm with @ptillet on this before the merge.

Jokeren avatar Mar 08 '23 17:03 Jokeren

Honestly, I think we can just merge. The max divergence seems very reasonable considering that bfloat16 has fewer mantissa bits than float16.
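
(As a quick illustration of that gap, not part of the original exchange: bf16 stores 7 mantissa bits vs. 10 for fp16, so its machine epsilon is 8x larger.)

import torch

# bf16 keeps 7 stored mantissa bits vs. 10 for fp16, so its spacing at 1.0
# (machine epsilon) is 8x coarser; absolute error grows with magnitude.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125    (2**-7)
print(torch.finfo(torch.float16).eps)   # 0.0009765625 (2**-10)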

ptillet avatar Mar 08 '23 17:03 ptillet

I see, but we have to update the error limits first, right? I'll do that sometime later this week.

Jokeren avatar Mar 08 '23 17:03 Jokeren

yeah, or we can reduce the input size. Seems like these errors are pretty unlikely

ptillet avatar Mar 08 '23 17:03 ptillet

To make it more consistent with the existing test, I only reduced the decimal limit for v, and it passed.
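
(For readers unfamiliar with the decimal limit: it's the numpy-style tolerance used by almost-equal assertions. Below is a small illustration with made-up values of what dropping it by one decimal means; the actual test change may look different.)

import numpy as np

# assert_array_almost_equal checks abs(desired - actual) < 1.5 * 10**(-decimal),
# so decimal=2 allows ~0.015 per element while decimal=1 allows ~0.15, which
# comfortably covers the 0.03125 max difference reported earlier in the thread.
ref = np.float32(1.0)
tri = np.float32(1.03125)
np.testing.assert_array_almost_equal(ref, tri, decimal=1)     # passes
# np.testing.assert_array_almost_equal(ref, tri, decimal=2)   # would raise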

Jokeren avatar Mar 09 '23 02:03 Jokeren

I'm just wondering whether LLVM still doesn't support the bfloat16 type? I get the same error in another test case.

Alon-Lau avatar Apr 23 '24 10:04 Alon-Lau

I'm just wondering whether LLVM still doesn't support the bfloat16 type? I get the same error in another test case:

python/test/unit/language/test_core.py

Alon-Lau avatar Apr 23 '24 10:04 Alon-Lau