triton

Layout conversion error on H100

Open · calebthomas259 opened this issue 1 year ago · 2 comments

Hello,

My modified flash attention kernel gives me the following error when I run it on an H100 GPU, even though the kernel works fine on A100 and RTX 3060 GPUs:

python: ../../../lib/Analysis/Allocation.cpp:43: std::pair<llvm::SmallVector<unsigned int>, llvm::SmallVector<unsigned int> > mlir::triton::getCvtOrder(mlir::Attribute, mlir::Attribute): Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && "mma -> mma layout conversion is only supported on Ampere"' failed.
Aborted (core dumped)

I've reduced the kernel to the following minimal example, which crashes on the H100 with the same error but runs successfully on my RTX 3060:

import torch
import triton
import triton.language as tl

@triton.jit
def fwd_kernel(
    out_LHD,
    stride_o_l,
    stride_o_h,
    stride_o_d,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):

    # current thread block index
    cur_head = tl.program_id(0)

    # data
    x = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float16)
    y = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)

    # calculate
    p = tl.dot(y, x)
    alpha = tl.zeros([BLOCK_M], dtype=tl.float16)
    out = alpha[:, None] * p + p

    # store output
    offs_o = (
        tl.arange(0, BLOCK_M)[:, None] * stride_o_l
        + cur_head * stride_o_h
        + tl.arange(0, BLOCK_D)[None, :] * stride_o_d
    )
    tl.store(out_LHD + offs_o, out)


def main():

    # parameters
    X_len = 84
    H = 8
    D = 32
    BLOCK_M = 64
    BLOCK_N = 64
    device = torch.device("cuda:0")
    torch.manual_seed(999)

    # setup and launch kernel
    out = torch.empty((X_len, H, D), device=device, dtype=torch.float16)
    fwd_kernel[(H,)](
        out,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_D=D,
        num_warps=8,
    )

if __name__ == "__main__":
    main()

The H100 system is a LambdaLabs Ubuntu machine, with these software versions installed via conda:

  • torchtriton 2.3.1
  • python 3.11.6
  • pytorch 2.3.1
  • pytorch-cuda 12.1

Thanks for the help!

calebthomas259 · Jul 29 '24 22:07

Apparently this is a known issue with the Hopper architecture: https://github.com/triton-lang/triton/pull/2627

calebthomas259 · Jul 31 '24 15:07

In case anyone else has a similar problem: I was able to work around the issue by removing all num_warps=8 autotune configurations from my flash attention kernel.

calebthomas259 · Aug 02 '24 15:08
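The workaround above can be applied automatically instead of maintaining separate config lists per machine. Below is a minimal sketch, assuming (as this thread suggests but does not prove) that num_warps=8 is what triggers the assertion on Hopper. It treats compute capability 9.x (the H100 is 9.0) as Hopper and leaves configs for other GPUs untouched. The helper name filter_configs_for_hopper and the FakeConfig stand-in are hypothetical illustrations, not part of Triton's API.

```python
from collections import namedtuple


def filter_configs_for_hopper(configs, capability, unsafe_num_warps=(8,)):
    """Drop autotune configs whose num_warps is in `unsafe_num_warps` when the
    GPU is Hopper-class (compute capability major version 9, e.g. H100).

    `configs` can be any objects with a `num_warps` attribute, such as
    triton.Config instances. On non-Hopper GPUs the list is returned unchanged.
    Assumption: num_warps=8 is the trigger, per the workaround in this thread.
    """
    major, _minor = capability
    if major != 9:
        return list(configs)
    kept = [cfg for cfg in configs if cfg.num_warps not in unsafe_num_warps]
    if not kept:
        # The autotuner needs at least one config to choose from.
        raise ValueError("all autotune configs were filtered out")
    return kept


# Hypothetical stand-in for triton.Config so this sketch runs without a GPU or Triton.
FakeConfig = namedtuple("FakeConfig", ["kwargs", "num_warps"])

if __name__ == "__main__":
    configs = [
        FakeConfig({"BLOCK_M": 64, "BLOCK_N": 64}, 4),
        FakeConfig({"BLOCK_M": 64, "BLOCK_N": 64}, 8),
        FakeConfig({"BLOCK_M": 128, "BLOCK_N": 64}, 8),
    ]
    # Hopper (9, 0): only the 4-warp config survives.
    print([c.num_warps for c in filter_configs_for_hopper(configs, (9, 0))])  # [4]
    # Ampere (8, 0): nothing is removed.
    print([c.num_warps for c in filter_configs_for_hopper(configs, (8, 0))])  # [4, 8, 8]
```

With real Triton, the list would be built from `triton.Config({...}, num_warps=w)` objects and filtered with `torch.cuda.get_device_capability()` before being passed as `configs=` to `@triton.autotune`, together with the kernel's own `key` list. Filtering once at import time keeps the kernel source identical across A100 and H100 machines, at the cost of never benchmarking 8-warp configs on Hopper.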