Repository: Dao-AILab/flash-attention

[CuteDSL] Internal compiler error (ICE) for the flash combine kernel on Blackwell

Status: Open · opened by drisspg 1 month ago · 0 comments

Summary

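The test below fails with an internal compiler error (ICE) when compiling for Blackwell (`sm_100a`): LLVM translation fails on a leftover `builtin.unrealized_conversion_cast` from `i64` to `i32`.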

Repro

git clone git@github.com:Dao-AILab/flash-attention.git
cd flash-attention
pip install flash_attn/cute/
pytest -v tests/cute/test_flash_attn.py -k "test_flash_attn_combine[1-1-64-dtype0]"

Output

============================= test session starts ==============================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0 -- /home/dev/.conda/envs/dev/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /home/dev/meta/flash-attention/tests
configfile: pyproject.toml
plugins: hypothesis-6.138.7, rerunfailures-16.0.1
collecting ... collected 3588 items / 3587 deselected / 1 selected

tests/cute/test_flash_attn.py::test_flash_attn_combine[1-1-64-dtype0] FAILED [100%]

=================================== FAILURES ===================================
____________________ test_flash_attn_combine[1-1-64-dtype0] ____________________

self = <cutlass.cutlass_dsl.cutlass.CuTeDSL object at 0x7f96fc81a090>
module = <cutlass._mlir._mlir_libs._mlir.ir.Module object at 0x7f96f9b717b0>
pipeline = 'builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3 enable-device-assertions=false link-libraries= toolkitPath=/usr/local/cuda-12.9 cubin-chip=sm_100a },external-kernel-for-gpu-launch)'
shared_libs = ['/home/dev/.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/lib/libmlir_cuda_runtime.so', '/home/dev/....ner_utils.so', '/home/dev/.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/lib/libmlir_runner_utils.so']
function_name = 'cutlass___call___flash_attncuteflash_fwd_combineFlashAttentionForwardCombine_object_at__Tensorgmemo1_Tensorgmemo1_Tensorgmemo1_Tensorgmem_o_1_None_None_None_None_CUstream0x0'

    def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
        """
        Compile and JIT an MLIR module.
        """
    
        try:
            self.diagnostic()
    
            orig_stdout = sys.stdout
            orig_stderr = sys.stderr
            sys.stderr = redirect_stderr = io.StringIO()
            sys.stdout = redirect_stdout = io.StringIO()
    
            try:
>               kernel = self.compiler_provider.compile_and_jit(
                    module,
                    pipeline,
                    shared_libs=shared_libs,
                    cuda_toolkit=self.envar.cuda_toolkit,
                    arch=self.envar.arch,
                )

../../.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:949: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py:178: in compile_and_jit
    self.compile(
../../.conda/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py:160: in compile
    raise e
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <cutlass.base_dsl.compiler.Compiler object at 0x7f96fc81a0c0>
module = <cutlass._mlir._mlir_libs._mlir.ir.Module object at 0x7f96f9b717b0>
pipeline = 'builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3 enable-device-assertions=false link-libraries= toolkitPath=/usr/local/cuda-12.9 cubin-chip=sm_100a },external-kernel-for-gpu-launch)'
cuda_toolkit = '/usr/local/cuda-12.9', arch = 'sm_100a', enable_verifier = False

    def compile(
        self,
        module,
        pipeline: str,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compiles the module by invoking the pipeline."""
        try:
            pm = self.passmanager.PassManager.parse(pipeline)
            pm.enable_verifier(enable_verifier)
>           pm.run(module.operation)
E           cutlass._mlir._mlir_libs._site_initialize.<locals>.MLIRError: Failure while executing pass pipeline:
E           error: unknown: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
E            note: unknown: see current operation: %185 = "builtin.unrealized_conversion_cast"(%184) : (i64) -> i32
E           error: unknown: Failed creating the llvm::Module.
E            note: unknown: see current operation: 
E             "gpu.module"() <{sym_name = "kernels", targets = [#nvvm.target<O = 3, chip = "sm_100a">]}> ({
E               "llvm.mlir.global"() <{addr_space = 3 : i32, alignment = 1024 : i64, dso_local, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage<external>, sym_name = "__dynamic_shmem__0", visibility_ = 0 : i64}> ({
E               }) : () -> ()
E               "llvm.func"() <{CConv = #llvm.cconv<ccc>, function_type = !llvm.func<void (struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>, struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, struct<(ptr<1>, struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>)>, i32, i32, i32, i32, i32, i32)>, linkage = #llvm.linkage<external>, sym_name = "kernel_cutlass_kernel_flash_attncuteflash_fwd_combineFlashAttentionForwardCombine_object_at__tensorptrf32gmemalign16odiv41div4div4div4_tensorptrf32gmemo1_tensorptrf32gmemalign16odiv41div4_0", visibility_ = 0 : i64}> ({
E               ^bb0(%arg0: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>, %arg1: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, %arg2: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>, %arg3: !llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>)>, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32):
E                 %0 = "llvm.mlir.undef"() : () -> vector<4xf32>
E                 %1 = "llvm.mlir.undef"() : () -> vector<1xf32>
E                 %2 = "llvm.mlir.constant"() <{value = 128 : i32}> : () -> i32
E                 %3 = "llvm.mlir.undef"() : () -> !llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>
E                 %4 = "llvm.mlir.undef"() : () -> !llvm.struct<(i32, i32, i32, i32)>
E                 %5 = "llvm.mlir.undef"() : () -> !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
E                 %6 = "llvm.mlir.undef"() : () -> !llvm.struct<(i32, i32)>
E                 %7 = "llvm.mlir.undef"() : () -> !llvm.struct<(i32, i32, i32)>
E                 %8 = "llvm.mlir.constant"() <{value = -1 : i64}> : () -> i64
E                 %9 = "llvm.mlir.constant"() <{value = false}> : () -> i1
E                 %10 = "llvm.mlir.constant"() <{value = 0 : i64}> : () -> i64
E                 %11 = "llvm.mlir.constant"() <{value = 1024 : i32}> : () -> i32
E                 %12 = "llvm.mlir.addressof"() <{global_name = @__dynamic_shmem__0}> : () -> !llvm.ptr<3>
E                 %13 = "llvm.mlir.constant"() <{value = 32 : i64}> : () -> i64
E                 %14 = "llvm.mlir.constant"() <{value = 8 : i64}> : () -> i64
E                 %15 = "llvm.mlir.constant"() <{value = 16 : i32}> : () -> i32
E                 %16 = "llvm.mlir.constant"() <{value = 0 : i32}> : () -> i32
E                 %17 = "llvm.mlir.constant"() <{value = 240 : i32}> : () -> i32
E                 %18 = "llvm.mlir.constant"() <{value = 0xFF800000 : f32}> : () -> f32
E                 %19 = "llvm.mlir.constant"() <{value = 64 : i32}> : () -> i32
E                 %20 = "llvm.mlir.constant"() <{value = -1 : i32}> : () -> i32
E                 %21 = "llvm.mlir.constant"() <{value = 3 : i32}> : () -> i32
E                 %22 = "llvm.mlir.constant"() <{value = 31 : i32}> : () -> i32
E                 %23 = "llvm.mlir.constant"() <{value = 2 : i32}> : () -> i32
E                 %24 = "llvm.mlir.constant"() <{value = 8 : i32}> : () -> i32
E                 %25 = "llvm.mlir.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
E                 %26 = "llvm.mlir.constant"() <{value = 1.44269502 : f32}> : () -> f32
E                 %27 = "llvm.mlir.constant"() <{value = 0.693147182 : f32}> : () -> f32
E                 %28 = "llvm.mlir.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
E                 %29 = "llvm.mlir.constant"() <{value = dense<0xFF800000> : vector<1xf32>}> : () -> vector<1xf32>
E                 %30 = "llvm.mlir.constant"() <{value = dense<0.000000e+00> : vector<4xf32>}> : () -> vector<4xf32>
E                 %31 = "llvm.mlir.constant"() <{value = 1 : i32}> : () -> i32
E                 %32 = "llvm.mlir.constant"() <{value = 4 : i32}> : () -> i32
E                 %33 = "llvm.alloca"(%32) <{alignment = 16 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E                 %34 = "llvm.alloca"(%32) <{alignment = 16 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E                 %35 = "llvm.alloca"(%32) <{alignment = 16 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E                 %36 = "llvm.alloca"(%31) <{alignment = 4 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E                 %37 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E                 %38 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i32}> : (i32) -> !llvm.ptr
E                 %39 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = f32}> : (i32) -> !llvm.ptr
E                 %40 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i64}> : (i32) -> !llvm.ptr
E                 %41 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i32}> : (i32) -> !llvm.ptr
E                 %42 = "llvm.alloca"(%31) <{alignment = 32 : i64, elem_type = i32}> : (i32) -> !llvm.ptr
E                 %43 = "llvm.extractvalue"(%arg0) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>) -> !llvm.ptr<1>
E                 %44 = "llvm.extractvalue"(%arg1) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>) -> !llvm.ptr<1>
E                 %45 = "llvm.extractvalue"(%arg2) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>) -> !llvm.ptr<1>
E                 %46 = "llvm.extractvalue"(%arg3) <{position = array<i64: 0>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>)>) -> !llvm.ptr<1>
E                 %47 = "nvvm.read.ptx.sreg.tid.x"() <{range = #llvm.constant_range<i32, 0, 1024>}> : () -> i32
E                 %48 = "nvvm.read.ptx.sreg.ctaid.x"() <{range = #llvm.constant_range<i32, 0, 2147483647>}> : () -> i32
E                 %49 = "nvvm.read.ptx.sreg.ctaid.y"() <{range = #llvm.constant_range<i32, 0, 65535>}> : () -> i32
E                 %50 = "nvvm.read.ptx.sreg.ctaid.z"() <{range = #llvm.constant_range<i32, 0, 65535>}> : () -> i32
E                 %51 = "llvm.getelementptr"(%12) <{elem_type = i8, rawConstantIndices = array<i32: 1024>}> : (!llvm.ptr<3>) -> !llvm.ptr<3>
E                 %52 = "llvm.getelementptr"(%12) <{elem_type = i8, rawConstantIndices = array<i32: 1152>}> : (!llvm.ptr<3>) -> !llvm.ptr<3>
E                 %53 = "llvm.extractvalue"(%arg1) <{position = array<i64: 1>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>)>) -> !llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>
E                 %54 = "llvm.extractvalue"(%53) <{position = array<i64: 0>}> : (!llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>) -> !llvm.struct<(i32, i32, i32, i32)>
E                 %55 = "llvm.extractvalue"(%54) <{position = array<i64: 1>}> : (!llvm.struct<(i32, i32, i32, i32)>) -> i32
E                 %56 = "llvm.extractvalue"(%arg0) <{position = array<i64: 1>}> : (!llvm.struct<(ptr<1>, struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>)>) -> !llvm.struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>
E                 %57 = "llvm.extractvalue"(%56) <{position = array<i64: 0>}> : (!llvm.struct<(struct<(i32, i32, i32, i32, i32)>, struct<(i32, i32, i32, i32)>)>) -> !llvm.struct<(i32, i32, i32, i32, i32)>
E                 %58 = "llvm.extractvalue"(%57) <{position = array<i64: 0>}> : (!llvm.struct<(i32, i32, i32, i32, i32)>) -> i32
E                 %59 = "llvm.extractvalue"(%57) <{position = array<i64: 3>}> : (!llvm.struct<(i32, i32, i32, i32, i32)>) -> i32
E                 %60 = "llvm.mul"(%58, %59) <{overflowFlags = #llvm.overflow<none>}> : (i32, i32) -> i32
E                 %61 = "llvm.sext"(%50) : (i32) -> i64
E                 %62 = "llvm.extractvalue"(%53) <{position = array<i64: 1>}> : (!llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>) -> !llvm.struct<(i32, i32, i32)>
E                 %63 = "llvm.extractvalue"(%62) <{position = array<i64: 2>}> : (!llvm.struct<(i32, i32, i32)>) -> i32
E                 %64 = "llvm.sext"(%63) : (i32) -> i64
E                 %65 = "llvm.mul"(%61, %64) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E                 %66 = "llvm.ptrtoint"(%44) : (!llvm.ptr<1>) -> i64
E                 %67 = "llvm.mul"(%65, %13) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E                 %68 = "llvm.sdiv"(%67, %14) : (i64, i64) -> i64
E                 %69 = "llvm.mul"(%68, %14) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E                 %70 = "llvm.icmp"(%67, %69) <{predicate = 1 : i64}> : (i64, i64) -> i1
E                 %71 = "llvm.icmp"(%67, %10) <{predicate = 2 : i64}> : (i64, i64) -> i1
E                 %72 = "llvm.icmp"(%71, %9) <{predicate = 1 : i64}> : (i1, i1) -> i1
E                 %73 = "llvm.and"(%70, %72) : (i1, i1) -> i1
E                 %74 = "llvm.add"(%68, %8) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E                 %75 = "llvm.select"(%73, %74, %68) <{fastmathFlags = #llvm.fastmath<none>}> : (i1, i64, i64) -> i64
E                 %76 = "llvm.add"(%66, %75) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
E                 %77 = "llvm.inttoptr"(%76) : (i64) -> !llvm.ptr<1>
E                 %78 = "llvm.extractvalue"(%53) <{position = array<i64: 0, 0>}> : (!llvm.struct<(struct<(i32, i32, i32, i32)>, struct<(i32, i32, i32)>)>) -> i32
E

...

[Remaining output truncated for length.]

Posted by drisspg on Oct 07 '25 at 18:10