fix import flashinfer error on AMD GPUs
Getting the CUDA arch flags raises an error when importing flashinfer on AMD GPUs. Since flashinfer is not officially supported on the AMD platform, just disable it when flashinfer is not installed.
@akaitsuki-ii I think this fix needs to be revisited, due to:
- When flashinfer is installed on AMD GPUs, `HAS_FLASHINFER` gets set to `True` in globals when e.g. MI300X returns (9, 4) from `torch.cuda.get_device_capability()` (`TORCH_CUDA_ARCH_LIST` is an NVIDIA-specific env var)
- `torch_cpp_ext._get_cuda_arch_flags()` is NVIDIA-specific, so it will crash on AMD GPUs when `HAS_FLASHINFER = True` (there exists a ROCm counterpart called `_get_rocm_arch_flags`; see the sketch after this list)
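A minimal sketch of how the two helpers could be dispatched. Note that both are private helpers in `torch.utils.cpp_extension`, so this is an assumption about a non-public API that may change between PyTorch versions:

```python
import torch
from torch.utils import cpp_extension as torch_cpp_ext

def get_arch_flags():
    # NVIDIA and ROCm each have their own private arch-flag helper;
    # dispatch on torch.version.hip, which is set only on ROCm builds.
    if torch.version.hip:
        return torch_cpp_ext._get_rocm_arch_flags()
    return torch_cpp_ext._get_cuda_arch_flags()
```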
If disabling flashinfer is desired on AMD GPUs, then the `HAS_FLASHINFER` check in globals could be augmented with `torch.version.hip`, so something like:
```python
import os

import torch

try:
    from flashinfer.prefill import single_prefill_with_kv_cache

    # FlashInfer has no official ROCm support; torch.version.hip is set
    # only on ROCm builds, so bail out explicitly on AMD GPUs.
    if torch.version.hip:
        raise ImportError("FlashInfer not supported on AMD GPUs")
    HAS_FLASHINFER = True

    def get_cuda_arch():
        major, minor = torch.cuda.get_device_capability()
        return f"{major}.{minor}"

    cuda_arch = get_cuda_arch()
    os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch
    print(f"Set TORCH_CUDA_ARCH_LIST to {cuda_arch}")
except ImportError as e:
    print("Warning: ", type(e).__name__, "–", e)
    HAS_FLASHINFER = False
```
And the suggested

```python
if HAS_FLASHINFER:
    torch_cpp_ext._get_cuda_arch_flags()
```

can be added/kept (the purpose of that function is to raise an error if an unknown CUDA arch is detected).
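For completeness, a hedged usage sketch of why the flag matters at the call site, assuming the `HAS_FLASHINFER` and `single_prefill_with_kv_cache` names from the snippets above; `fallback_attention` is a hypothetical stand-in for whatever non-FlashInfer path the caller has:

```python
def prefill_attention(q, k, v):
    # q: [qo_len, num_qo_heads, head_dim], k/v: [kv_len, num_kv_heads, head_dim]
    # (layout per flashinfer's single_prefill_with_kv_cache docs).
    if HAS_FLASHINFER:
        return single_prefill_with_kv_cache(q, k, v)
    # fallback_attention is hypothetical: any non-FlashInfer attention path.
    return fallback_attention(q, k, v)
```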
Hi, does this PR https://github.com/feifeibear/long-context-attention/pull/150 solve the problem?
@feifeibear Yes, LGTM. Thank you, and let me close this PR.