Compilation error when dtype=float16, but no error when dtype=float32
Repro (though perhaps not a minimal one): conv_relu_conv_relu_float16.py, conv_relu_conv_relu_float32.py
`call()` does the forward computation of

```python
torch.nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.ReLU(inplace=True),
)
```
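For reference, here is a hedged sketch of a driver along the lines of the attached repro scripts (the module is the one above; the `torch.compile`/TorchInductor wiring and the input shape are my assumptions, not necessarily the contents of the attached files):

```python
import torch
import torch.nn as nn

def compiled_model(dtype):
    # Module under test, as described above.
    model = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 192, kernel_size=5, padding=2),
        nn.ReLU(inplace=True),
    ).cuda().to(dtype)
    # TorchInductor generates the Triton kernels (kernels 0/1/2 below).
    return torch.compile(model)

# Input shape is an assumption; any NCHW input should hit the same kernels.
x16 = torch.randn(8, 3, 224, 224, device="cuda", dtype=torch.float16)
x32 = x16.float()
compiled_model(torch.float32)(x32)  # works
compiled_model(torch.float16)(x16)  # LLVM ERROR: Broken function found
```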
kernel 0 does the layout transformation (NCHW -> NHWC); kernels 1 and 2 are fused conv+relu kernels.
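Conceptually, kernel 0 computes the equivalent of the following (an illustration only; the generated Triton kernel does this in a single pass rather than via permute):

```python
x_nhwc = x.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC channels-last copy
```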
It works fine when dtype=float32, but when dtype=float16 compilation aborts with `LLVM ERROR: Broken function found, compilation aborted!`