flash-attention
Support AMD ROCm on FlashAttention 2
- This PR implements the AMD / ROCm version of the C++ flash API:
- mha_fwd
- mha_varlen_fwd
- mha_bwd
- mha_varlen_bwd
- The kernel implementation comes from Composable Kernel.
- The C++ API is the same as the original version, so the existing Python interface can be used unchanged (see the sketch below).
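For example, the usual call through the Python interface should look identical on ROCm and CUDA. The snippet below is only a minimal sketch based on the upstream README (fp16/bf16 tensors of shape (batch, seqlen, nheads, headdim) on the GPU):

import torch
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, nheads, headdim), fp16 or bf16, resident on the GPU
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)

# Same entry point as the CUDA build; ROCm PyTorch also exposes the GPU as "cuda"
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1024, 16, 64])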
+1 please merge
@tridao I would be very happy to see this change!
+1
+100
+1
For those with AMD devices, can you help test this PR?
Hi @rocking5566, I get this error when I try to install this.
- I clone main
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
- I add the remote and merge it into main
git remote add rocm2 https://github.com/ROCm/flash-attention.git
git fetch rocm2
git checkout -b rocm_merging rocm2/ck_tile
git checkout main
git merge rocm_merging
- I install flash-attention
pip install .
Here is the error message I see:
dist.fetch_build_eggs(dist.setup_requires)
running bdist_wheel
Traceback (most recent call last):
File "<string>", line 2, in <module>
File "<pip-setuptools-caller>", line 34, in <module>
File "/home/resmp/flash-attention/setup.py", line 443, in <module>
setup(
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/__init__.py", line 103, in setup
return distutils.core.setup(**attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/core.py", line 185, in setup
return run_commands(dist)
^^^^^^^^^^^^^^^^^^
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
dist.run_commands()
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
self.run_command(cmd)
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/dist.py", line 989, in run_command
super().run_command(command)
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
cmd_obj.run()
File "/home/resmp/flash-attention/setup.py", line 400, in run
wheel_url, wheel_filename = get_wheel_url()
^^^^^^^^^^^^^^^
File "/home/resmp/flash-attention/setup.py", line 369, in get_wheel_url
torch_cuda_version = parse(torch.version.cuda)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/packaging/version.py", line 54, in parse
return Version(version)
^^^^^^^^^^^^^^^^
File "/home/resmp/miniconda3/envs/axolotl/lib/python3.12/site-packages/packaging/version.py", line 198, in __init__
match = self._regex.search(version)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: expected string or bytes-like object, got 'NoneType'
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for flash-attn
Running setup.py clean for flash-attn
Failed to build flash-attn
ERROR: Could not build wheels for flash-attn, which is required to install pyproject.toml-based projects
I checked that torch.version.hip is working correctly:
(axolotl) resmp@tw001:~/flash-attention$ python -c "import torch; print(torch.__version__); print(torch.version.hip)"
2.4.0.dev20240412+rocm6.0
6.0.32830-d62f6a171
(axolotl) resmp@tw001:~/flash-attention$
This points to something hard-coded in setup.py that needs to be generalized.
I asked Claude about this:
To generalize the setup.py for both CUDA and ROCm environments, we need to make some modifications. Here's an overview of the key changes:
- Detect the GPU environment (CUDA or ROCm) automatically.
- Use conditional logic to set appropriate flags and file paths for each environment.
- Generalize the extension creation process.
Here are the main sections that need to be modified:
- GPU Environment Detection: Replace the current detection method with a more robust one:
IS_ROCM = False
if torch.version.hip is not None:
    IS_ROCM = True
elif torch.version.cuda is not None:
    IS_ROCM = False
else:
    raise RuntimeError("Neither CUDA nor ROCm is available")
- Submodule Update: Use a conditional statement to update the correct submodule:
if not IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
- Compiler Flags and Source Files: Create separate lists for CUDA and ROCm:
cuda_extra_compile_args = {
    "cxx": ["-O3", "-std=c++17"] + generator_flag,
    "nvcc": [
        "-O3",
        "-std=c++17",
        # ... (other CUDA-specific flags)
    ] + generator_flag + cc_flag,
}
rocm_extra_compile_args = {
    "cxx": ["-O3", "-std=c++17"] + generator_flag,
    "hipcc": [
        "-O3",
        "-std=c++17",
        # ... (other ROCm-specific flags)
    ] + generator_flag + cc_flag,
}
cuda_sources = [
    "csrc/flash_attn/flash_api.cpp",
    # ... (list of CUDA source files)
]
rocm_sources = [
    "csrc/flash_attn_ck/flash_api.cpp",
    # ... (list of ROCm source files)
]
- Extension Creation: Use a conditional statement to create the appropriate extension:
if not IS_ROCM:
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=cuda_sources,
            extra_compile_args=cuda_extra_compile_args,
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
else:
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=rocm_sources,
            extra_compile_args=rocm_extra_compile_args,
            include_dirs=[
                Path(this_dir) / "csrc" / "composable_kernel" / "include",
                Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
                Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
            ],
        )
    )
- Wheel URL Generation: Modify the get_wheel_url() function to handle both CUDA and ROCm:
def get_wheel_url():
    # ... (existing code)
    if IS_ROCM:
        gpu_tag = f"rocm{torch.version.hip}"
    else:
        gpu_tag = f"cu{cuda_version}"
    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+{gpu_tag}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    # ... (rest of the function)
These changes will make setup.py more general and able to handle both CUDA and ROCm environments. Remember to test thoroughly in both environments to ensure everything works as expected.
I am not qualified to make this sort of change! But, I hope this helps to narrow down the problem
@ehartford Thanks for your valuable comment.
To compile the code for ROCm, you can try
python setup.py install
This works for me.
I will take a look at your script.
Thank you, running python setup.py install worked. I will run a full build tonight using flash attention and verify that it's working.
@tridao we tested this on MI300X and verified that it's working
@ehartford
I fixed the ROCm environment detection and get_wheel_url().
pip install . works now.
Will try this thank you
I got this error while building fmha_bwd_d128_fp16_batch for gfx1100:
/root/code/flash-attention/csrc/composable_kernel/include/ck_tile/core/arch/generic_memory_space_atomic_hip.hpp:66:19: error: static assertion failed due to requirement '(std::is_same<_Float16, int>::value && (4 == 1)) || (std::is_same<_Float16, unsigned int>::value && (4 == 1)) || (std::is_same<_Float16, float>::value && (4 == 1 || 4 == 2)) || (std::is_same<_Float16, double>::value && (4 == 1 || 4 == 2)) || (std::is_same<_Float16, unsigned short>::value && (4 == 2 || 4 == 4))': wrong! not implemented,
It was due to _Float16 not implemented in the instantiation of: /root/code/flash-attention/csrc/composable_kernel/include/ck_tile/core/tensor/buffer_view_hip.hpp:413:28: note: in instantiation of function template specialization 'ck_tile::buffer_view<ck_tile::address_space_enum::global, _Float16, long, true, ck_tile::amd_buffer_coherence_enum::coherence_default>::atomic_add<ck_tile::thread_buffer<_Float16, 4>, false>' requested here 413 | this->template atomic_add<X>(i, is_valid_element, x);
The template doesn't accept T=float16, but accepts T=bf16 with N==2 or 4:
template <typename T, index_t N> CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
Is this expected? I saw the allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
I also get an error when compiling for the 7900 XTX (gfx1100). Tell me, does flash-attention support this card?
They don't support it yet ("native" should just mean the build host's own target). I dug into it, and here are my findings. The first problem I hit was that float16 is not implemented in atomic_add in csrc/composable_kernel/include/ck_tile/core/arch/generic_memory_space_atomic.hpp. I still don't understand why MI300 doesn't need it, but we can stop generate.py from generating fp16 kernels: there are 3 Python files in csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/ that we can modify to skip fp16. The second problem I'm dealing with is a configuration check, static_assert(kKPack % K3 == 0), in csrc/composable_kernel/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp. The key is that the warp size of the 7900 XTX is 32, while MI200 and MI300 have 64.
kBlockSize = NumWarps * warp_size = (1 * 4 * 1) * 32 = 128
kMPerBlock (M0) = 128
kNPerBlock (N0) = 128
kQKHeaddim = 32
BiasDataType = unsigned short (16 bits)
total_pixels = kMPerBlock * kNPerBlock / kBlockSize = 128 * 128 / 128 = 128
Since total_pixels > 32, N1 = 8
N0 = kNPerBlock / N1 = 128 / 8 = 16
M3 = total_pixels / N1 = 128 / 8 = 16
kKPack = 16 / sizeof(BiasDataType) = 16 / 2 = 8
8 % 16 != 0
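For what it's worth, the arithmetic above can be scripted as a quick sanity check. This is only a sketch that re-derives the numbers under the assumptions stated in this comment (N1 = 8 whenever total_pixels > 32, and the failing check being kKPack against M3); it is not taken from the CK sources:

def kkpack_check(warp_size, num_warps=4, m_per_block=128, n_per_block=128, bias_bytes=2):
    # Mirrors the derivation above for this bwd tile configuration.
    k_block_size = num_warps * warp_size
    total_pixels = m_per_block * n_per_block // k_block_size
    assert total_pixels > 32              # holds for both 32- and 64-wide warps here
    n1 = 8                                # per the derivation: N1 = 8 when total_pixels > 32
    m3 = total_pixels // n1
    k_kpack = 16 // bias_bytes            # 16 / sizeof(BiasDataType)
    return k_kpack, m3, k_kpack % m3 == 0

print(kkpack_check(warp_size=32))         # 7900 XTX: (8, 16, False) -> the static_assert fires
print(kkpack_check(warp_size=64))         # MI200/MI300: (8, 8, True)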
Here is the problematic template instantiation I figured out; hope this helps:
ck_tile::BlockFmhaBwdPipelineProblem<
    unsigned short, unsigned short, unsigned short, unsigned short, float, float, float, unsigned short, unsigned char, unsigned short, unsigned short, unsigned short, unsigned short, unsigned short, unsigned short,
    // QDataType, KDataType, VDataType, GemmDataType, LSEDataType, AccDataType, DDataType, BiasDataType, RandValOutputDataType, ODataType, OGradDataType, QGradDataType, KGradDataType, VGradDataType, BiasGradDataType
    ck_tile::TileFmhaBwdShape<                                   // BlockFmhaShape
        ck_tile::sequence<128, 128, 32, 32, 32, 32, 32, 32, 32>, // BlockTile: [kM0, kN0, kK0, kK1, kK2, kK3, kK4, kQKHeaddim, kVHeaddim]
        ck_tile::sequence<1, 4, 1>,                              // Gemm0BlockWarps
        ck_tile::sequence<32, 32, 16>,                           // Gemm0WarpTile
        ck_tile::sequence<4, 1, 1>,                              // Gemm1BlockWarps
        ck_tile::sequence<32, 32, 16>,                           // Gemm1WarpTile
        ck_tile::sequence<1, 4, 1>,                              // Gemm2BlockWarps
        ck_tile::sequence<32, 32, 16>,                           // Gemm2WarpTile
        ck_tile::sequence<4, 1, 1>,                              // Gemm3BlockWarps
        ck_tile::sequence<32, 32, 16>,                           // Gemm3WarpTile
        ck_tile::sequence<4, 1, 1>,                              // Gemm4BlockWarps
        ck_tile::sequence<32, 32, 16>                            // Gemm4WarpTile
    >, // NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}), kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(), warp_size = 32
    false,                                     // kIsGroupMode
    ck_tile::SimplifiedGenericAttentionMask<>, // FmhaMask
    ck_tile::TileFmhaTraits<false, false, false, false, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, 1> // Traits
>
I'll keep digging into it. My plan is to generate a group of kernels that fit the 7900 XTX; I will post my results if I have any findings.
I tried this on a known good configuration, using TRL
I am able to run it without flash attention, and I am able to run it with the ROCm version of flash attention.
But using this PR, I get an error.
I also reproduced the same error in axolotl, so it's not a problem specific to TRL.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
File "/home/erichartford/flash-attention/./thingy.py", line 107, in <module>
trainer.train()
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 451, in train
output = super().train(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 3307, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 3338, in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/peft/peft_model.py", line 1430, in forward
return self.base_model(
^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
return self.model.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 967, in forward
layer_outputs = self._gradient_checkpointing_func(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 255, in forward
outputs = run_function(*args)
^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 467, in forward
attn_output = self._flash_attention_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 532, in _flash_attention_forward
attn_output = flash_attn_func(
^^^^^^^^^^^^^^^^
File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 882, in flash_attn_func
return FlashAttnFunc.apply(
^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 548, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
^^^^^^^^^^^^^^^^^^^^
File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
^^^^^^^^^^^^^^^^^^^^
TypeError: fwd(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: float, arg6: float, arg7: bool, arg8: int, arg9: int, arg10: bool, arg11: Optional[torch.Generator]) -> list[torch.Tensor]
Invoked with: tensor([[[[ 1.3281e-01, 2.2461e-01, -3.9551e-02, ..., 3.5156e-02,
5.8838e-02, 1.4648e-01],
[-1.3733e-02, 3.4332e-03, -4.8633e-01, ..., 4.6631e-02,
2.4805e-01, 4.3701e-02],
....
[-2.1362e-02, -4.7607e-02, -2.6611e-02, ..., -1.8677e-02,
-3.3447e-02, -7.6660e-02]]]], device='cuda:0', dtype=torch.bfloat16), None, None, 0.0, 0.08838834764831845, True, -1, -1, 0.0, False, None
Also please note that we were able to test successfully when using this branch in the ROCm repo.
It is only when we merge it into the main branch of the Dao-AILab repo that this error appears.
We have independently reproduced this error on several machines, and several environments.
@rocking5566 we would be happy to have a call and troubleshoot this with you, if you would like to reach out. [email protected]
@minzhezhou Thanks for your time. We only support MI200 & MI300 at this time, which is why we put gfx90a/gfx94x in the allowed_archs list. Other targets should be blocked anyway.
Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for Navi, or that it is just not on the official roadmap yet?
How about gfx908?
@minzhezhou Thanks for your attention. Targets other than gfx90a & gfx94x are not on the roadmap yet; currently we are focusing on MI300 platforms.
Now that the conflicts are resolved, can we merge this?
Thanks!
I have an MI50 and an MI100 and am looking forward to gfx906 and gfx908 support; they are so much cheaper.
You can use this script to query the arch and exclude any archs that have not been well tested (only gfx942 has been tested):
target_amdarch=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')
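If it helps, the same query can be done from Python before configuring the build. This is just a sketch, assuming rocminfo is installed at the usual path and reusing the allowed_archs list quoted earlier in the thread:

import re
import subprocess

allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

# Grab the first gfx target reported by rocminfo, e.g. "gfx90a" or "gfx942".
out = subprocess.run(["/opt/rocm/bin/rocminfo"], capture_output=True, text=True).stdout
match = re.search(r"gfx\w+", out)
target_arch = match.group(0) if match else None

if target_arch not in allowed_archs:
    raise RuntimeError(f"Detected arch {target_arch!r} is not in the tested list {allowed_archs}")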
gfx906 and gfx908 can at least compile; they have warp_size = 64.
The current FA compiles successfully on MI100. However, I found some test cases might fail... We may fix this in the future.
That's excellent news! I can't wait to try it on my MI100s.
MI100 worked, but MI50 failed all tests.
That is why we only claim MI200 & MI300 are officially supported. Other platforms might fail some test cases on some versions of ROCm.
I installed ROCm 6.1.x on the MI50; all tests failed -.-
hoping for mi50 support~
I am wondering why the KV cache / paged attention API fwd_kvcache is not supported.
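For context, on the CUDA build that op is reached via flash_attn_with_kvcache in the Python interface. A minimal sketch of the kind of decoding-step call being asked about (upstream signature, not something the ROCm port provides yet):

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_cache_len = 2, 16, 64, 4096

# Pre-allocated KV cache plus the new K/V for the current decoding step.
k_cache = torch.zeros(batch, max_cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(batch, max_cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")

# Appends k_new/v_new at position cache_seqlens and attends over the cached keys.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)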