Layout conversion error on H100
Hello,
My modified flash attention kernel gives me the following error when I run it on an H100 GPU, even though it works fine on an A100 and an RTX 3060:
python: ../../../lib/Analysis/Allocation.cpp:43: std::pair<llvm::SmallVector<unsigned int>, llvm::SmallVector<unsigned int> > mlir::triton::getCvtOrder(mlir::Attribute, mlir::Attribute): Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && "mma -> mma layout conversion is only supported on Ampere"' failed.
Aborted (core dumped)
I've reduced the kernel to the following minimal example, which crashes on the H100 with the same error but runs successfully on my RTX 3060:
import torch
import triton
import triton.language as tl


@triton.jit
def fwd_kernel(
    out_LHD,
    stride_o_l,
    stride_o_h,
    stride_o_d,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # current thread block index
    cur_head = tl.program_id(0)
    # data
    x = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float16)
    y = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
    # calculate
    p = tl.dot(y, x)
    alpha = tl.zeros([BLOCK_M], dtype=tl.float16)
    out = alpha[:, None] * p + p
    # store output
    offs_o = (
        tl.arange(0, BLOCK_M)[:, None] * stride_o_l
        + cur_head * stride_o_h
        + tl.arange(0, BLOCK_D)[None, :] * stride_o_d
    )
    tl.store(out_LHD + offs_o, out)


def main():
    # parameters
    X_len = 84
    H = 8
    D = 32
    BLOCK_M = 64
    BLOCK_N = 64
    device = torch.device("cuda:0")
    torch.manual_seed(999)
    # setup and launch kernel
    out = torch.empty((X_len, H, D), device=device, dtype=torch.float16)
    fwd_kernel[(H,)](
        out,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_D=D,
        num_warps=8,
    )


if __name__ == "__main__":
    main()
The H100 system is a LambdaLabs Ubuntu instance, with these software versions installed via conda:
- torchtriton 2.3.1
- python 3.11.6
- pytorch 2.3.1
- pytorch-cuda 12.1
Thanks for the help!
Apparently this is a known issue with the Hopper architecture: https://github.com/triton-lang/triton/pull/2627
In case anyone else runs into a similar problem: I was able to work around the issue by removing all num_warps=8 autotune configurations from my flash attention kernel, as sketched below.
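A minimal sketch of what that change looks like, assuming a typical @triton.autotune setup (the tile sizes, num_stages values, and key names below are illustrative, not the ones from my actual kernel):

import triton
import triton.language as tl

# Keep only configs with num_warps <= 4; the num_warps=8 configs
# triggered the "mma -> mma layout conversion" assertion on H100.
configs = [
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
    # removed for H100:
    # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
]


@triton.autotune(configs=configs, key=["seq_len", "head_dim"])
@triton.jit
def fwd_kernel(
    out_ptr,
    seq_len,
    head_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # kernel body elided; only the autotune configuration matters here
    pass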