onnxruntime

Add GQA support for ROCm

Open cloudhan opened this issue 1 year ago • 1 comments

depends on

  • [x] #20913
  • [x] #21028
  • [x] #21030

cloudhan avatar Jun 13 '24 09:06 cloudhan

The CI tests revealed failures like the following:

kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

.local/lib/python3.9/site-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py:58: in test_gqa_past_flash_attention
    parity_check_gqa_past(
/onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_cuda.py:1702: in parity_check_gqa_past
    numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7f28fe08e280>, array([[[[-8.6060e-03,  4.1046e-02, -2.5604e-02, ..., ...n, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]]], dtype=float16))
kwds = {'equal_nan': True, 'err_msg': ' with Config(batch_size=5, sequence_length=1, kv_sequence_length=2048, past_sequence_l...ue, rotary_interleaved=False, packed=True', 'header': 'Not equal to tolerance rtol=0.002, atol=0.002', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.002, atol=0.002
E            with Config(batch_size=5, sequence_length=1, kv_sequence_length=2048, past_sequence_length=227, num_heads=32, kv_num_heads=8, head_size=256, ep=rocm), causal=True, local=False, past_format=1, rotary=True, rotary_interleaved=False, packed=True
E           x and y nan location mismatch:
E            x: array([[[[-8.6060e-03,  4.1046e-02, -2.5604e-02, ..., -7.4829e-02,
E                      5.8060e-03, -2.0828e-03],
E                    [ 4.0207e-03,  7.6523e-03,  1.5244e-02, ..., -4.6326e-02,...
E            y: array([[[[nan, nan, nan, ..., nan, nan, nan],
E                    [nan, nan, nan, ..., nan, nan, nan],
E                    [nan, nan, nan, ..., nan, nan, nan],...

/opt/miniconda/envs/rocm-ci/lib/python3.9/contextlib.py:79: AssertionError

There were also some sparse 'inf' values in other tests. However, these appeared in the y value, that is, the reference value. I reproduced many of these issues locally, and updating torch (along with torch triton) to 2.3.1 eliminated all of them.
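To tell which side of a parity check holds the non-finite values, one option is to count them before calling numpy.testing.assert_allclose. The following is only a minimal sketch: the helper name nan_inf_report and the toy arrays are hypothetical and are not part of the onnxruntime test suite.

```python
import numpy as np


def nan_inf_report(out, out_ref):
    """Count nan/inf entries in the tested output and in the reference.

    Hypothetical helper for illustration; not part of onnxruntime.
    """
    out = np.asarray(out)
    out_ref = np.asarray(out_ref)
    return {
        "out_nan": int(np.isnan(out).sum()),
        "out_inf": int(np.isinf(out).sum()),
        "ref_nan": int(np.isnan(out_ref).sum()),
        "ref_inf": int(np.isinf(out_ref).sum()),
    }


# Toy float16 data: the output under test is finite, the reference is not.
out = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float16)
out_ref = np.array([[0.1, 0.2], [np.nan, np.inf]], dtype=np.float16)

report = nan_inf_report(out, out_ref)
print(report)
```

If the non-zero counts are all on the reference side, as in the failure above, the reference implementation (here, the torch side) is the first thing to look at rather than the kernel under test.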

cloudhan avatar Jun 28 '24 07:06 cloudhan

LGTM except there is a build error:

https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1423537&view=logs&j=7536d2cd-87d4-54fe-4891-bfbbf2741d83&t=66420422-c7d6-5f71-625c-4b7851c9b9ba&l=3997

CMakeFiles/onnxruntime_providers_rocm.dir/onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu.o
/onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu:5:10: fatal error: 'ck_tile/core/numeric/integer.hpp' file not found
#include "ck_tile/core/numeric/integer.hpp"

tianleiwu avatar Jul 01 '24 17:07 tianleiwu

@snnn this needs an ES approval. Some packages in CI are updated because the reference impl produced some nan and inf values; see my previous comment.

cloudhan avatar Jul 03 '24 02:07 cloudhan