Support AMD ROCm on FlashAttention 2

Open • rocking5566 opened this issue 1 year ago • 20 comments

  • This PR implements the AMD/ROCm version of the C++ flash API:
    1. mha_fwd
    2. mha_varlen_fwd
    3. mha_bwd
    4. mha_varlen_bwd
  • The kernel implementation comes from Composable Kernel (CK).
  • The C++ API is the same as in the original version, so the Python interface can be shared between the two.

rocking5566 • Jun 26 '24

+1 please merge

deke997 • Jun 28 '24

@tridao I would be very happy to see this change!

ehartford • Jun 28 '24

+1

larrysingh • Jun 28 '24

+100

Bellk17 • Jun 28 '24

+1

tomasikp • Jun 28 '24

For those with AMD devices, can you help test this PR?

tridao • Jun 28 '24

We'd be happy to give you access at TensorWave.

Please email me to get it set up:

[email protected]

deke997 • Jul 01 '24

Hi @rocking5566, I get this error when I try to install this:

  1. I cloned main:
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
  2. I added the remote and merged it into main:
git remote add rocm2 https://github.com/ROCm/flash-attention.git
git fetch rocm2
git checkout -b rocm_merging rocm2/ck_tile
git checkout main
git merge rocm_merging
  3. I installed flash-attention:
pip install .

Here is the error message I see:

        dist.fetch_build_eggs(dist.setup_requires)
      running bdist_wheel
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "/home/resmp/flash-attention/setup.py", line 443, in <module>
          setup(
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/__init__.py", line 103, in setup
          return distutils.core.setup(**attrs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/core.py", line 185, in setup
          return run_commands(dist)
                 ^^^^^^^^^^^^^^^^^^
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
          dist.run_commands()
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
          self.run_command(cmd)
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
          cmd_obj.run()
        File "/home/resmp/flash-attention/setup.py", line 400, in run
          wheel_url, wheel_filename = get_wheel_url()
                                      ^^^^^^^^^^^^^^^
        File "/home/resmp/flash-attention/setup.py", line 369, in get_wheel_url
          torch_cuda_version = parse(torch.version.cuda)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/packaging/version.py", line 54, in parse
          return Version(version)
                 ^^^^^^^^^^^^^^^^
        File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/packaging/version.py", line 198, in __init__
          match = self._regex.search(version)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      TypeError: expected string or bytes-like object, got 'NoneType'
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for flash-attn
  Running setup.py clean for flash-attn
Failed to build flash-attn
ERROR: Could not build wheels for flash-attn, which is required to install pyproject.toml-based projects

I checked whether torch.version.hip is set correctly:

(axolotl) resmp@tw001:~/flash-attention$ python -c "import torch; print(torch.__version__); print(torch.version.hip)"
2.4.0.dev20240412+rocm6.0
6.0.32830-d62f6a171
(axolotl) resmp@tw001:~/flash-attention$ 

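This indicates that something in setup.py is hard-coded for CUDA and needs to be generalized: get_wheel_url() calls parse(torch.version.cuda), and torch.version.cuda is None on a ROCm build of PyTorch.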
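The traceback fails because a None value reaches packaging's parse(). A minimal stdlib sketch of a guarded lookup, assuming the caller passes in the values of torch.version.cuda and torch.version.hip (strings or None); detect_backend and parse_major_minor are hypothetical helper names, not functions from setup.py:

```python
import re
from typing import Optional, Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like "12.1" or "6.0.32830-d62f6a171"."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def detect_backend(
    cuda_version: Optional[str], hip_version: Optional[str]
) -> Tuple[str, Tuple[int, int]]:
    """Pick the GPU backend from the torch.version.cuda / torch.version.hip values.

    On ROCm builds of PyTorch the CUDA version is None, so it must never be
    parsed unconditionally.
    """
    if hip_version is not None:
        return "rocm", parse_major_minor(hip_version)
    if cuda_version is not None:
        return "cuda", parse_major_minor(cuda_version)
    raise RuntimeError("Neither CUDA nor ROCm is available")


if __name__ == "__main__":
    # The HIP version string printed in the environment above.
    print(detect_backend(None, "6.0.32830-d62f6a171"))  # ('rocm', (6, 0))
```

In a real setup.py this would be called as detect_backend(torch.version.cuda, torch.version.hip); keeping the parsing separate from torch makes the branch logic easy to test without a GPU.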

I asked Claude about this, and this was its answer:

To generalize the setup.py for both CUDA and ROCm environments, we need to make some modifications. Here's an overview of the key changes:

  1. Detect the GPU environment (CUDA or ROCm) automatically.
  2. Use conditional logic to set appropriate flags and file paths for each environment.
  3. Generalize the extension creation process.

Here are the main sections that need to be modified:

  1. GPU Environment Detection: Replace the current detection method with a more robust one:
IS_ROCM = False
if torch.version.hip is not None:
    IS_ROCM = True
elif torch.version.cuda is not None:
    IS_ROCM = False
else:
    raise RuntimeError("Neither CUDA nor ROCm is available")
  2. Submodule Update: Use a conditional statement to update the correct submodule:
if not IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
  3. Compiler Flags and Source Files: Create separate lists for CUDA and ROCm:
cuda_extra_compile_args = {
    "cxx": ["-O3", "-std=c++17"] + generator_flag,
    "nvcc": [
        "-O3", "-std=c++17",
        # ... (other CUDA-specific flags)
    ] + generator_flag + cc_flag
}

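rocm_extra_compile_args = {
    "cxx": ["-O3", "-std=c++17"] + generator_flag,
    # torch.utils.cpp_extension hands the "nvcc" entry to hipcc on ROCm builds,
    # so the GPU flags belong under "nvcc", not under a "hipcc" key.
    "nvcc": [
        "-O3", "-std=c++17",
        # ... (other ROCm-specific flags)
    ] + generator_flag + cc_flag
}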

cuda_sources = [
    "csrc/flash_attn/flash_api.cpp",
    # ... (list of CUDA source files)
]

rocm_sources = [
    "csrc/flash_attn_ck/flash_api.cpp",
    # ... (list of ROCm source files)
]
  4. Extension Creation: Use a conditional statement to create the appropriate extension:
if not IS_ROCM:
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=cuda_sources,
            extra_compile_args=cuda_extra_compile_args,
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
else:
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=rocm_sources,
            extra_compile_args=rocm_extra_compile_args,
            include_dirs=[
                Path(this_dir) / "csrc" / "composable_kernel" / "include",
                Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
                Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
            ],
        )
    )
  5. Wheel URL Generation: Modify the get_wheel_url() function to handle both CUDA and ROCm:
def get_wheel_url():
    # ... (existing code)
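    if IS_ROCM:
        # torch.version.hip looks like "6.0.32830-d62f6a171"; a "-" would break the
        # wheel filename, so keep only the major/minor numbers.
        hip_version = parse(torch.version.hip.split("-")[0])
        gpu_tag = f"rocm{hip_version.major}{hip_version.minor}"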
    else:
        gpu_tag = f"cu{cuda_version}"
    
    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+{gpu_tag}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    # ... (rest of the function)

These changes will make the setup.py more generalized and able to handle both CUDA and ROCm environments. Remember to test thoroughly in both environments to ensure everything works as expected.
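As a rough check on the naming logic, here is a stdlib sketch that assembles a filename in the format used by the get_wheel_url() snippet. A wheel filename uses "-" as its field separator, so a GPU tag built from a raw HIP version such as "6.0.32830-d62f6a171" would introduce an extra field; the sketch keeps only major/minor. The "rocm{major}{minor}" tag, the package version "2.5.9", and the platform string are assumptions for illustration only.

```python
from typing import Tuple


def wheel_filename(
    package: str,
    flash_version: str,
    backend: str,
    gpu_version: Tuple[int, int],
    torch_version: str,
    cxx11_abi: bool,
    python_tag: str,
    platform_name: str,
) -> str:
    """Build <pkg>-<ver>+<gpu>torch<t>cxx11abi<ABI>-<py>-<py>-<plat>.whl."""
    major, minor = gpu_version
    # Hypothetical tag scheme: "cu121" for CUDA 12.1, "rocm60" for ROCm 6.0.
    prefix = "rocm" if backend == "rocm" else "cu"
    gpu_tag = f"{prefix}{major}{minor}"
    abi = str(cxx11_abi).upper()  # "TRUE" or "FALSE"
    return (
        f"{package}-{flash_version}+{gpu_tag}torch{torch_version}cxx11abi{abi}"
        f"-{python_tag}-{python_tag}-{platform_name}.whl"
    )


if __name__ == "__main__":
    name = wheel_filename("flash_attn", "2.5.9", "rocm", (6, 0), "2.4", False,
                          "cp312", "linux_x86_64")
    print(name)
    # flash_attn-2.5.9+rocm60torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
    # Five fields (name, version, python, abi, platform) means exactly four dashes.
    assert name.count("-") == 4
```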

I am not qualified to make this sort of change, but I hope this helps narrow down the problem.

ehartford • Jul 03 '24

@ehartford Thank you for your valuable comment. To compile the code for ROCm, you can try python setup.py install; this works for me. I will take a look at your script.

rocking5566 • Jul 03 '24

Thank you, running python setup.py worked. I will run a full build tonight using flash attention and verify that it's working.

ehartford • Jul 04 '24

@tridao We tested this on MI300X and verified that it's working.

deke997 • Jul 06 '24

@ehartford
I fixed the ROCm environment detection and get_wheel_url(); pip install . works now.

rocking5566 • Jul 08 '24

Will try this, thank you.

ehartford • Jul 08 '24

I got this error while building fmha_bwd_d128_fp16_batch for gfx1100:
/root/code/flash-attention/csrc/composable_kernel/include/ck_tile/core/arch/generic_memory_space_atomic_hip.hpp:66:19: error: static assertion failed due to requirement '(std::is_same<_Float16, int>::value && (4 == 1)) || (std::is_same<_Float16, unsigned int>::value && (4 == 1)) || (std::is_same<_Float16, float>::value && (4 == 1 || 4 == 2)) || (std::is_same<_Float16, double>::value && (4 == 1 || 4 == 2)) || (std::is_same<_Float16, unsigned short>::value && (4 == 2 || 4 == 4))': wrong! not implemented,

It is caused by _Float16 not being implemented, in the instantiation of:
/root/code/flash-attention/csrc/composable_kernel/include/ck_tile/core/tensor/buffer_view_hip.hpp:413:28: note: in instantiation of function template specialization 'ck_tile::buffer_view<ck_tile::address_space_enum::global, _Float16, long, true, ck_tile::amd_buffer_coherence_enum::coherence_default>::atomic_add<ck_tile::thread_buffer<_Float16, 4>, false>' requested here
  413 | this->template atomic_add<X>(i, is_valid_element, x);

The template doesn't accept T=float16, but it accepts T=bf16 with N == 2 or 4:
template <typename T, index_t N> CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)

Is this expected? I saw that allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"].
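The static_assert in the error above is effectively a whitelist of (element type, vector width N) pairs. A small Python model of that whitelist, transcribed only from the assertion text (bf16 corresponds to the unsigned short entry, matching the reading in this comment), shows why the fp16 instantiation with N = 4 is rejected. This mirrors the error message; it is not CK code.

```python
# (T, N) pairs accepted by the static_assert in generic_memory_space_atomic_hip.hpp,
# transcribed from the compiler error above. _Float16 does not appear at all.
ALLOWED_ATOMIC_ADD = {
    "int": {1},
    "unsigned int": {1},
    "float": {1, 2},
    "double": {1, 2},
    "unsigned short": {2, 4},  # how bf16 shows up in the error message
}


def atomic_add_supported(dtype: str, n: int) -> bool:
    """Return True if the assertion would accept thread_buffer<dtype, n>."""
    return n in ALLOWED_ATOMIC_ADD.get(dtype, set())


if __name__ == "__main__":
    print(atomic_add_supported("_Float16", 4))        # False: fp16 is not listed
    print(atomic_add_supported("unsigned short", 4))  # True: the bf16 path
```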

minzhezhou • Jul 12 '24

I also have an error when compiling for the 7900 XTX (gfx1100). Does flash-attention support this card?

hackey • Jul 14 '24

> I also have an error when compiling for 7900xtx (gfx1100). Write, does flash-attention support this card?

They don't support it yet; "native" should mean for CPU. I dug into it, and here are my findings.

The first problem I hit was that float16 is not implemented in atomic_add in csrc/composable_kernel/include/ck_tile/core/arch/generic_memory_space_atomic.hpp. I still don't understand why MI300 doesn't need it, but we can stop generate.py from generating fp16 kernels: there are 3 Python files in csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/ that we can modify to skip fp16.

The second problem I'm dealing with is a configuration check, static_assert(kKPack % K3 == 0), in csrc/composable_kernel/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp. The key is that the warp size of the 7900 XTX is 32, whereas MI200 and MI300 have a warp size of 64.

kBlockSize = NumWarps * warp_size = (1 * 4 * 1) * 32 = 128
kMPerBlock (M0) = 128
kNPerBlock (N0) = 128
kQKHeaddim = 32
BiasDataType = unsigned short (16 bits)
total_pixels = kMPerBlock * kNPerBlock / kBlockSize = 128 * 128 / 128 = 128
Since total_pixels > 32, N1 = 8
N0 = kNPerBlock / N1 = 128 / 8 = 16
M3 = total_pixels / N1 = 128 / 8 = 16
kKPack = 16 / sizeof(BiasDataType) = 16 / 2 = 8
8 % 16 != 0
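The arithmetic above can be replayed for both warp sizes with a short sketch. It hard-codes N1 = 8 because the comment only gives that rule for total_pixels > 32, and, like the calculation above, it checks kKPack against M3. The function name and parameters are illustrative, not CK's actual policy code.

```python
def kpack_check(warp_size: int, num_warps: int = 4, m_per_block: int = 128,
                n_per_block: int = 128, bias_bytes: int = 2):
    """Replay the kKPack divisibility arithmetic from the comment above.

    Returns (kKPack, M3, passes), where passes means kKPack % M3 == 0.
    """
    block_size = num_warps * warp_size                       # kBlockSize
    total_pixels = m_per_block * n_per_block // block_size
    if total_pixels <= 32:
        # The comment only states the N1 rule for total_pixels > 32.
        raise NotImplementedError("N1 for total_pixels <= 32 is not given")
    n1 = 8
    m3 = total_pixels // n1
    k_pack = 16 // bias_bytes  # 16 bytes / sizeof(BiasDataType)
    return k_pack, m3, k_pack % m3 == 0


if __name__ == "__main__":
    print(kpack_check(warp_size=32))  # (8, 16, False): wave32, e.g. 7900 XTX
    print(kpack_check(warp_size=64))  # (8, 8, True): wave64, e.g. MI200/MI300
```

Under these same formulas a warp size of 64 doubles kBlockSize, halves total_pixels and M3, and the check passes, which is consistent with the point that wave64 GPUs do not hit this assertion.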

Here is the problematic template instantiation I figured out; I hope this helps:


ck_tile::BlockFmhaBwdPipelineProblem<
        unsigned short, unsigned short, unsigned short, unsigned short, float,          float,        float,        unsigned short, unsigned char,          unsigned short, unsigned short, unsigned short, unsigned short, unsigned short, unsigned short, 
        // QDataType,   KDataType,      VDataType,      GemmDataType,   LSEDataType,    AccDataType,  DDataType,    BiasDataType,   RandValOutputDataType,  ODataType,      OGradDataType,  QGradDataType,  KGradDataType,  VGradDataType,  BiasGradDataType
        ck_tile::TileFmhaBwdShape<  // BlockFmhaShape
            ck_tile::sequence<128, 128, 32, 32, 32, 32, 32, 32, 32>,    // BlockTile: [kM0, KN0, KK0, KK1, KK2, KK3, KK4, kQKHeaddim, kVHeaddim]
            ck_tile::sequence<1, 4, 1>,         // Gemm0BlockWarps
            ck_tile::sequence<32, 32, 16>,      // Gemm0WarpTile
            ck_tile::sequence<4, 1, 1>,         // Gemm1BlockWarps
            ck_tile::sequence<32, 32, 16>,      // Gemm1WarpTile
            ck_tile::sequence<1, 4, 1>,         // Gemm2BlockWarps
            ck_tile::sequence<32, 32, 16>,      // Gemm2WarpTile
            ck_tile::sequence<4, 1, 1>,         // Gemm3BlockWarps
            ck_tile::sequence<32, 32, 16>,      // Gemm3WarpTile
            ck_tile::sequence<4, 1, 1>,         // Gemm4BlockWarps
            ck_tile::sequence<32, 32, 16>       // Gemm4WarpTile
        >, // NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}), kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(), warp_size = 32
        false, // kIsGroupMode
        ck_tile::SimplifiedGenericAttentionMask<>,  // FmhaMask
        ck_tile::TileFmhaTraits<false, false, false, false, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, 1>    // Traits
    >

I'll keep digging into it. My plan is to generate a group of kernels that fit the 7900 XTX, and I will post my results if I find anything.

minzhezhou • Jul 14 '24

I tried this on a known good configuration, using TRL.

I am able to run it without flash attention, and I am able to run it with the ROCm version of flash attention.

But with this PR, I get an error.

I also reproduced the same error in axolotl, so it's not a problem specific to TRL.

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
  File "/home/erichartford/flash-attention/./thingy.py", line 107, in <module>
    trainer.train()
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 451, in train
    output = super().train(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/peft/peft_model.py", line 1430, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 967, in forward
    layer_outputs = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 255, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 467, in forward
    attn_output = self._flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 532, in _flash_attention_forward
    attn_output = flash_attn_func(
                  ^^^^^^^^^^^^^^^^
  File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 882, in flash_attn_func
    return FlashAttnFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 548, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
                                                                ^^^^^^^^^^^^^^^^^^^^
TypeError: fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: float, arg6: float, arg7: bool, arg8: int, arg9: int, arg10: bool, arg11: Optional[torch.Generator]) -> list[torch.Tensor]

Invoked with: tensor([[[[ 1.3281e-01,  2.2461e-01, -3.9551e-02,  ...,  3.5156e-02,
            5.8838e-02,  1.4648e-01],
          [-1.3733e-02,  3.4332e-03, -4.8633e-01,  ...,  4.6631e-02,
            2.4805e-01,  4.3701e-02],
          ....
          [-2.1362e-02, -4.7607e-02, -2.6611e-02,  ..., -1.8677e-02,
           -3.3447e-02, -7.6660e-02]]]], device='cuda:0', dtype=torch.bfloat16), None, None, 0.0, 0.08838834764831845, True, -1, -1, 0.0, False, None

Also, please note that we were able to test successfully when using this branch in the ROCm repo.

This error appears only when we merge that branch into the main branch of the Dao-AILab repo.

We have independently reproduced this error on several machines, and several environments.
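Reading the TypeError above: the supported fwd() signature has nine arguments after the three tensors (arg3 to arg11), but the call passes ten values after the tensor dump, with an extra float 0.0 between the two window-size ints (-1, -1) and the boolean. That suggests the Python interface on main now passes one more argument than the ROCm C++ binding accepts (possibly a newly added softcap argument; this is an assumption, not confirmed in this thread). The stdlib sketch below reproduces the mismatch with a stand-in function; rocm_fwd and its parameter names are hypothetical, and it assumes the trimmed tensor dump held q, k, and v, which are replaced by placeholder strings.

```python
import inspect
from typing import Optional, Sequence, Tuple


def rocm_fwd(q, k, v, out, alibi_slopes, p_dropout, softmax_scale,
             is_causal, window_left, window_right, return_softmax, gen):
    """Hypothetical stand-in with the 12-argument shape listed in the error."""
    return "ok"


def arity_mismatch(fn, args: Sequence) -> Optional[Tuple[int, int]]:
    """Return (accepted, passed) if args can't bind to fn positionally, else None."""
    sig = inspect.signature(fn)
    try:
        sig.bind(*args)
    except TypeError:
        return len(sig.parameters), len(args)
    return None


# Values from the "Invoked with:" line; the tensors are replaced by placeholders.
call_args = ("q", "k", "v", None, None, 0.0, 0.08838834764831845,
             True, -1, -1, 0.0, False, None)

if __name__ == "__main__":
    print(arity_mismatch(rocm_fwd, call_args))  # (12, 13)
    # Dropping the extra 0.0 (index 10) makes the call bind.
    print(arity_mismatch(rocm_fwd, call_args[:10] + call_args[11:]))  # None
```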

@rocking5566 we would be happy to have a call and troubleshoot this with you, if you would like to reach out. [email protected]

ehartford • Jul 16 '24

@minzhezhou Thanks for your time. We only support MI200 and MI300 at this time, which is why we put gfx90a/gfx94x in the allowed_archs list. Other targets should be blocked anyway.

poyenc • Jul 17 '24

> @minzhezhou Thanks for your time. We only support mi200 & mi300 at this time. Thus we put gfx90a/gfx94x in the allowed_archs list. Other targets should be blocked anyway..

Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for Navi, or that it is just not on the official roadmap yet? How about gfx908?

minzhezhou • Jul 17 '24

> Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for navi or it is not on the official roadmap yet? How about gfx908?

@minzhezhou Thanks for your attention. Targets other than gfx90a and gfx94x are not on the roadmap yet; currently we are focusing on MI300 platforms.

poyenc • Jul 18 '24

Now that the conflicts are resolved, can we merge this?

Thanks!

deke997 • Jul 23 '24

I have MI50 and MI100 cards and am looking forward to gfx906 and gfx908 support; they are so much cheaper.

linchen111 • Jul 24 '24

> Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for navi or it is not on the official roadmap yet? How about gfx908?

> @minzhezhou thanks for your attention. targets other than gfx90a & gfx94x is not on the roadmap yet. currently we are focusing on mi300 platforms.

You can use this command to query the arch and exclude any arch that has not been well tested (only gfx942 has been tested):

target_amdarch=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')
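The same check can be done from Python before building. A hedged sketch, assuming rocminfo output in which the first "gfx..." token names the GPU (the same assumption the grep makes) and using the allowed_archs list quoted earlier in this thread; first_gfx_arch, read_rocminfo, and is_allowed are hypothetical helpers, and the sample text is illustrative rather than captured output.

```python
import re
import subprocess
from typing import Optional

# allowed_archs as quoted earlier in this thread.
ALLOWED_ARCHS = {"native", "gfx90a", "gfx940", "gfx941", "gfx942"}


def first_gfx_arch(rocminfo_text: str) -> Optional[str]:
    """Return the first gfx target name, similar to the grep above but
    stopping at the end of the name (e.g. "gfx90a" from "gfx90a:sramecc+")."""
    match = re.search(r"gfx[0-9a-z]+", rocminfo_text)
    return match.group(0) if match else None


def read_rocminfo(path: str = "/opt/rocm/bin/rocminfo") -> str:
    """Run rocminfo and return its stdout (requires a ROCm install)."""
    return subprocess.run([path], capture_output=True, text=True, check=True).stdout


def is_allowed(arch: Optional[str]) -> bool:
    return arch is not None and arch in ALLOWED_ARCHS


if __name__ == "__main__":
    sample = "  Name:                    gfx1100\n"
    arch = first_gfx_arch(sample)
    print(arch, is_allowed(arch))  # gfx1100 False
```

On a ROCm machine, first_gfx_arch(read_rocminfo()) would replace the sample string.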

> I have mi50 and mi100 , looking forward to gfx906 and gfx908 support, they are so much cheaper

gfx906 and gfx908 could at least compile, since they have warp_size = 64.

minzhezhou • Jul 24 '24

> I have mi50 and mi100 , looking forward to gfx906 and gfx908 support, they are so much cheaper

> gfx906 and gfx908 could at least compile, they have warp_size = 64.

The current FA compiles successfully on MI100. However, I found that some test cases may fail. We may fix this in the future.

rocking5566 • Jul 24 '24

That's excellent news! I can't wait to try it on my MI100s.

ehartford • Jul 24 '24

> I have mi50 and mi100 , looking forward to gfx906 and gfx908 support, they are so much cheaper

> gfx906 and gfx908 could at least compile, they have warp_size = 64.

> Current FA could compile successfully in MI100. However, I found some test cases might fail.... We may fix it in the future.

MI100 worked, but MI50 failed all tests.

linchen111 • Aug 01 '24

> I have mi50 and mi100 , looking forward to gfx906 and gfx908 support, they are so much cheaper

> gfx906 and gfx908 could at least compile, they have warp_size = 64.

> Current FA could compile successfully in MI100. However, I found some test cases might fail.... We may fix it in the future.

> mi100 worked , but mi50 failed on all test

That is why we only claim official support for MI200 and MI300. Other platforms might fail some test cases on some ROCm versions.

rocking5566 • Aug 01 '24

> I have mi50 and mi100 , looking forward to gfx906 and gfx908 support, they are so much cheaper

> gfx906 and gfx908 could at least compile, they have warp_size = 64.

> Current FA could compile successfully in MI100. However, I found some test cases might fail.... We may fix it in the future.

> mi100 worked , but mi50 failed on all test

> That is why we only claim MI200 & MI300 are officially support. Other platform might failed on some test cases for some version of ROCm.

I installed ROCm 6.1.x on MI50, and all tests failed -.-

Hoping for MI50 support~

linchen111 • Aug 01 '24

I am wondering why the KV cache / paged attention API (fwd_kvcache) is not supported.

foreverlms • Aug 04 '24