flash-attention
flash-attention copied to clipboard
[CuteDSL] ICE for flash combine kernel on blackwell
Summary
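Compiling the flash-attention combine kernel (FlashAttentionForwardCombine) with CuTe DSL for Blackwell (sm_100a) fails with an internal compiler error (ICE): the MLIR pass pipeline aborts with "LLVM Translation failed for operation: builtin.unrealized_conversion_cast". The test below reproduces it.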
Repro
git clone git@github.com:Dao-AILab/flash-attention.git
pip install flash_attn/cute/
pytest -v tests/cute/test_flash_attn.py -k "test_flash_attn_combine[1-1-64-dtype0]"
Output
============================= test session starts ==============================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0 -- /home/dev/.conda/envs/dev/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /home/dev/meta/flash-attention/tests
configfile: pyproject.toml
plugins: hypothesis-6.138.7, rerunfailures-16.0.1
collecting ... collected 3588 items / 3587 deselected / 1 selected
tests/cute/test_flash_attn.py::test_flash_attn_combine[1-1-64-dtype0] FAILED [100%]
=================================== FAILURES ===================================
____________________ test_flash_attn_combine[1-1-64-dtype0] ____________________
self = <cutlass.cutlass_dsl.cutlass.CuTeDSL object at 0x7f96fc81a090>
module = <cutlass._mlir._mlir_libs._mlir.ir.Module object at 0x7f96f9b717b0>
pipeline = 'builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3 enable-device-assertions=false link-libraries= toolkitPath=/usr/local/cuda-12.9 cubin-chip=sm_100a },external-kernel-for-gpu-launch)'
shared_libs = ['/home/dev/.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/lib/libmlir_cuda_runtime.so', '/home/dev/....ner_utils.so', '/home/dev/.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/lib/libmlir_runner_utils.so']
function_name = 'cutlass___call___flash_attncuteflash_fwd_combineFlashAttentionForwardCombine_object_at__Tensorgmemo1_Tensorgmemo1_Tensorgmemo1_Tensorgmem_o_1_None_None_None_None_CUstream0x0'
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
"""
Compile and JIT an MLIR module.
"""
try:
self.diagnostic()
orig_stdout = sys.stdout
orig_stderr = sys.stderr
sys.stderr = redirect_stderr = io.StringIO()
sys.stdout = redirect_stdout = io.StringIO()
try:
> kernel = self.compiler_provider.compile_and_jit(
module,
pipeline,
shared_libs=shared_libs,
cuda_toolkit=self.envar.cuda_toolkit,
arch=self.envar.arch,
)
../../.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:949:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py:178: in compile_and_jit
self.compile(
../../.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py:160: in compile
raise e
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <cutlass.base_dsl.compiler.Compiler object at 0x7f96fc81a0c0>
module = <cutlass._mlir._mlir_libs._mlir.ir.Module object at 0x7f96f9b717b0>
pipeline = 'builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3 enable-device-assertions=false link-libraries= toolkitPath=/usr/local/cuda-12.9 cubin-chip=sm_100a },external-kernel-for-gpu-launch)'
cuda_toolkit = '/usr/local/cuda-12.9', arch = 'sm_100a', enable_verifier = False
def compile(
self,
module,
pipeline: str,
cuda_toolkit: str = "",
arch: str = "",
enable_verifier=False,
):
"""Compiles the module by invoking the pipeline."""
try:
pm = self.passmanager.PassManager.parse(pipeline)
pm.enable_verifier(enable_verifier)
> pm.run(module.operation)
E cutlass._mlir._mlir_libs._site_initialize.<locals>.MLIRError: Failure while executing pass pipeline:
E error: unknown: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
E note: unknown: see current operation: %185 = "builtin.unrealized_conversion_cast"(%184) : (i64) -> i32
E error: unknown: Failed creating the llvm::Module.
E note: unknown: see current operation:
E "gpu.module"() <{sym_name = "kernels", targets = [#nvvm.target<O = 3, chip = "sm_100a">]}> ({
E "llvm.mlir.global"() <{addr_space = 3 : i32, alignment = 1024 : i64, dso_local, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage<external>, sym_name = "__dynamic_shmem__0", visibility_ = 0 : i64}> ({
E }) : () -> ()
E "llvm.func"() <{CConv = #llvm.cconv<ccc>, function_type = !llvm.func<void (struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>, struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, struct<(ptr<1>, struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>)>, i32, i32, i32, i32, i32, i32)>, linkage = #llvm.linkage<external>, sym_name = "kernel_cutlass_kernel_flash_attncuteflash_fwd_combineFlashAttentionForwardCombine_object_at__tensorptrf32gmemalign16odiv41div4div4div4_tensorptrf32gmemo1_tensorptrf32gmemalign16odiv41div4_0", visibility_ = 0 : i64}> ({
E ^bb0(%arg0: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>, %arg1: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, %arg2: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, %arg3: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>)>, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32):
E %0 = "llvm.mlir.undef"() : () -> vector<4xf32>
E %1 = "llvm.mlir.undef"() : () -> vector<1xf32>
E %2 = "llvm.mlir.constant"() <{value = 128 : i32}> : () -> i32
E %3 = "llvm.mlir.undef"() : () -> !llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>
E %4 = "llvm.mlir.undef"() : () -> !llvm.struct<(i32, i32, i32, i32)>
E %5 = "llvm.mlir.undef"() : () -> !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
E %6 = "llvm.mlir.undef"() : () -> !llvm.struct<(i32, i32)>
E %7 = "llvm.mlir.undef"() : () -> !llvm.struct<(i32, i32, i32)>
E %8 = "llvm.mlir.constant"() <{value = -1 : i64}> : () -> i64
E %9 = "llvm.mlir.constant"() <{value = false}> : () -> i1
E %10 = "llvm.mlir.constant"() <{value = 0 : i64}> : () -> i64
E %11 = "llvm.mlir.constant"() <{value = 1024 : i32}> : () -> i32
E %12 = "llvm.mlir.addressof"() <{global_name = @__dynamic_shmem__0}> : () -> !llvm.ptr<3>
E %13 = "llvm.mlir.constant"() <{value = 32 : i64}> : () -> i64
E %14 = "llvm.mlir.constant"() <{value = 8 : i64}> : () -> i64
E %15 = "llvm.mlir.constant"() <{value = 16 : i32}> : () -> i32
E %16 = "llvm.mlir.constant"() <{value = 0 : i32}> : () -> i32
E %17 = "llvm.mlir.constant"() <{value = 240 : i32}> : () -> i32
E %18 = "llvm.mlir.constant"() <{value = 0xFF800000 : f32}> : () -> f32
E %19 = "llvm.mlir.constant"() <{value = 64 : i32}> : () -> i32
E %20 = "llvm.mlir.constant"() <{value = -1 : i32}> : () -> i32
E %21 = "llvm.mlir.constant"() <{value = 3 : i32}> : () -> i32
E %22 = "llvm.mlir.constant"() <{value = 31 : i32}> : () -> i32
E %23 = "llvm.mlir.constant"() <{value = 2 : i32}> : () -> i32
E %24 = "llvm.mlir.constant"() <{value = 8 : i32}> : () -> i32
E %25 = "llvm.mlir.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
E %26 = "llvm.mlir.constant"() <{value = 1.44269502 : f32}> : () -> f32
E %27 = "llvm.mlir.constant"() <{value = 0.693147182 : f32}> : () -> f32
E %28 = "llvm.mlir.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
E %29 = "llvm.mlir.constant"() <{value = dense<0xFF800000> : vector<1xf32>}> : () -> vector<1xf32>
E %30 = "llvm.mlir.constant"() <{value = dense<0.000000e+00> : vector<4xf32>}> : () -> vector<4xf32>
E %31 = "llvm.mlir.constant"() <{value = 1 : i32}> : () -> i32
E %32 = "llvm.mlir.constant"() <{value = 4 : i32}> : () -> i32
E %33 = "llvm.alloca"(%32) <{alignment = 16 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E %34 = "llvm.alloca"(%32) <{alignment = 16 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E %35 = "llvm.alloca"(%32) <{alignment = 16 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E %36 = "llvm.alloca"(%31) <{alignment = 4 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E %37 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E %38 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i32}> : (i32) -> !llvm.ptr
E %39 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E %40 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i64}> : (i32) -> !llvm.ptr
E %41 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i32}> : (i32) -> !llvm.ptr
E %42 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i32}> : (i32) -> !llvm.ptr
E %43 = "llvm.extractvalue"(%arg0) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>) -> !llvm.ptr<1>
E %44 = "llvm.extractvalue"(%arg1) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>) -> !llvm.ptr<1>
E %45 = "llvm.extractvalue"(%arg2) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>) -> !llvm.ptr<1>
E %46 = "llvm.extractvalue"(%arg3) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>)>) -> !llvm.ptr<1>
E %47 = "nvvm.read.ptx.sreg.tid.x"() <{range = #llvm.constant_range<i32, 0, 1024>}> : () -> i32
E %48 = "nvvm.read.ptx.sreg.ctaid.x"() <{range = #llvm.constant_range<i32, 0, 2147483647>}> : () -> i32
E %49 = "nvvm.read.ptx.sreg.ctaid.y"() <{range = #llvm.constant_range<i32, 0, 65535>}> : () -> i32
E %50 = "nvvm.read.ptx.sreg.ctaid.z"() <{range = #llvm.constant_range<i32, 0, 65535>}> : () -> i32
E %51 = "llvm.getelementptr"(%12) <{elem_type = i8, rawConstantIndices = array<i32: 1024>}> : (!llvm.ptr<3>) -> !llvm.ptr<3>
E %52 = "llvm.getelementptr"(%12) <{elem_type = i8, rawConstantIndices = array<i32: 1152>}> : (!llvm.ptr<3>) -> !llvm.ptr<3>
E %53 = "llvm.extractvalue"(%arg1) <{position = array<i64: 1>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>) -> !llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>
E %54 = "llvm.extractvalue"(%53) <{position = array<i64: 0>}> : (!llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>) -> !llvm.struct<(i32, i32, i32, i32)>
E %55 = "llvm.extractvalue"(%54) <{position = array<i64: 1>}> : (!llvm.struct<(i32, i32, i32, i32)>) -> i32
E %56 = "llvm.extractvalue"(%arg0) <{position = array<i64: 1>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>) -> !llvm.struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>
E %57 = "llvm.extractvalue"(%56) <{position = array<i64: 0>}> : (!llvm.struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>) -> !llvm.struct<(i32, i32, i32, i32, i32)>
E %58 = "llvm.extractvalue"(%57) <{position = array<i64: 0>}> : (!llvm.struct<(i32, i32, i32, i32, i32)>) -> i32
E %59 = "llvm.extractvalue"(%57) <{position = array<i64: 3>}> : (!llvm.struct<(i32, i32, i32, i32, i32)>) -> i32
E %60 = "llvm.mul"(%58, %59) <{overflowFlags = #llvm.overflow<none>}> : (i32, i32) -> i32
E %61 = "llvm.sext"(%50) : (i32) -> i64
E %62 = "llvm.extractvalue"(%53) <{position = array<i64: 1>}> : (!llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>) -> !llvm.struct<(i32, i32, i32)>
E %63 = "llvm.extractvalue"(%62) <{position = array<i64: 2>}> : (!llvm.struct<(i32, i32, i32)>) -> i32
E %64 = "llvm.sext"(%63) : (i32) -> i64
E %65 = "llvm.mul"(%61, %64) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E %66 = "llvm.ptrtoint"(%44) : (!llvm.ptr<1>) -> i64
E %67 = "llvm.mul"(%65, %13) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E %68 = "llvm.sdiv"(%67, %14) : (i64, i64) -> i64
E %69 = "llvm.mul"(%68, %14) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E %70 = "llvm.icmp"(%67, %69) <{predicate = 1 : i64}> : (i64, i64) -> i1
E %71 = "llvm.icmp"(%67, %10) <{predicate = 2 : i64}> : (i64, i64) -> i1
E %72 = "llvm.icmp"(%71, %9) <{predicate = 1 : i64}> : (i1, i1) -> i1
E %73 = "llvm.and"(%70, %72) : (i1, i1) -> i1
E %74 = "llvm.add"(%68, %8) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E %75 = "llvm.select"(%73, %74, %68) <{fastmathFlags = #llvm.fastmath<none>}> : (i1, i64, i64) -> i64
E %76 = "llvm.add"(%66, %75) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E %77 = "llvm.inttoptr"(%76) : (i64) -> !llvm.ptr<1>
E %78 = "llvm.extractvalue"(%53) <{position = array<i64: 0, 0>}> : (!llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>) -> i32
E
...
TRUNCATED for length