Add GQA support for ROCm
Depends on:
- [x] #20913
- [x] #21028
- [x] #21030
The CI tests revealed failures like the following:
kw = {}
@wraps(func)
def standalone_func(*a, **kw):
> return func(*(a + p.args), **p.kwargs, **kw)
.local/lib/python3.9/site-packages/parameterized/parameterized.py:620:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py:58: in test_gqa_past_flash_attention
parity_check_gqa_past(
/onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_cuda.py:1702: in parity_check_gqa_past
numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x7f28fe08e280>, array([[[[-8.6060e-03, 4.1046e-02, -2.5604e-02, ..., ...n, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]]], dtype=float16))
kwds = {'equal_nan': True, 'err_msg': ' with Config(batch_size=5, sequence_length=1, kv_sequence_length=2048, past_sequence_l...ue, rotary_interleaved=False, packed=True', 'header': 'Not equal to tolerance rtol=0.002, atol=0.002', 'verbose': True}
@wraps(func)
def inner(*args, **kwds):
with self._recreate_cm():
> return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0.002, atol=0.002
E with Config(batch_size=5, sequence_length=1, kv_sequence_length=2048, past_sequence_length=227, num_heads=32, kv_num_heads=8, head_size=256, ep=rocm), causal=True, local=False, past_format=1, rotary=True, rotary_interleaved=False, packed=True
E x and y nan location mismatch:
E x: array([[[[-8.6060e-03, 4.1046e-02, -2.5604e-02, ..., -7.4829e-02,
E 5.8060e-03, -2.0828e-03],
E [ 4.0207e-03, 7.6523e-03, 1.5244e-02, ..., -4.6326e-02,...
E y: array([[[[nan, nan, nan, ..., nan, nan, nan],
E [nan, nan, nan, ..., nan, nan, nan],
E [nan, nan, nan, ..., nan, nan, nan],...
/opt/miniconda/envs/rocm-ci/lib/python3.9/contextlib.py:79: AssertionError
Some sparse `inf` values also showed up in other tests. However, these NaN and `inf` values occurred in the y value, i.e. the reference value. I reproduced many of these issues locally, and updating torch (along with torch triton) to 2.3.1 eliminated all of them.
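For anyone triaging a similar failure, here is a minimal sketch of how to tell which side of the comparison holds the non-finite values. The arrays are small hypothetical stand-ins for `out` and `out_ref`, and `summarize_nonfinite` is an illustrative helper, not something in the test suite. It also shows why `equal_nan=True` does not mask this case: NaNs only compare equal when they sit at the same positions in both arrays.

```python
import numpy as np


def summarize_nonfinite(name, arr):
    """Print and return (nan_count, inf_count) for an array."""
    arr = np.asarray(arr)
    nan_count = int(np.isnan(arr).sum())
    inf_count = int(np.isinf(arr).sum())
    print(f"{name}: nan={nan_count} inf={inf_count} of {arr.size} elements")
    return nan_count, inf_count


# Hypothetical stand-ins: a finite kernel output (x) and an all-NaN reference (y).
out = np.array([[-8.606e-03, 4.1046e-02], [4.0207e-03, 7.6523e-03]], dtype=np.float16)
out_ref = np.full_like(out, np.nan)

out_counts = summarize_nonfinite("out", out)
ref_counts = summarize_nonfinite("out_ref", out_ref)

# equal_nan=True treats NaNs as equal only at matching positions, so a finite
# output compared against a NaN reference still raises AssertionError.
try:
    np.testing.assert_allclose(out, out_ref, rtol=2e-3, atol=2e-3, equal_nan=True)
    mismatch = False
except AssertionError:
    mismatch = True
```

If the counts are non-zero only for the reference array, the problem is more likely in the reference implementation or its dependencies than in the kernel under test.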
LGTM except there is a build error:
https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1423537&view=logs&j=7536d2cd-87d4-54fe-4891-bfbbf2741d83&t=66420422-c7d6-5f71-625c-4b7851c9b9ba&l=3997
CMakeFiles/onnxruntime_providers_rocm.dir/onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu.o /onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu:5:10: fatal error: 'ck_tile/core/numeric/integer.hpp' file not found #include "ck_tile/core/numeric/integer.hpp"
@snnn This needs an es approval. Some packages in CI are updated because the reference implementation produced some NaN and `inf` values; see my previous comment.