
Compilation error when dtype=float16, but no error when dtype=float32

pyjhzwh opened this issue 2 years ago

Repro (though perhaps not a minimal one): conv_relu_conv_relu_float16.py and conv_relu_conv_relu_float32.py. In each script, call() performs the forward computation of

torch.nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.ReLU(inplace=True),
)

Kernel 0 does the layout transformation (NCHW -> NHWC); kernels 1 and 2 are fused conv+relu kernels. It works fine when dtype=float32, but when dtype=float16 it throws LLVM ERROR: Broken function found, compilation aborted!
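For reference, a minimal sketch of the model in eager PyTorch (not the author's generated repro scripts): the float32 path below runs as-is, while the float16 path that reportedly triggers the Triton error needs a CUDA device and is only shown commented out.

```python
import torch
import torch.nn as nn

# Same two-layer conv+relu stack as in the report.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.ReLU(inplace=True),
)

# float32 forward pass: works fine in eager mode.
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)  # spatial size: (224 + 2*2 - 11)//4 + 1 = 55

# The failing configuration, per the report, would be the same model and
# input cast to float16 and compiled through Triton, roughly:
#   model_fp16 = model.half().cuda()
#   x_fp16 = torch.randn(1, 3, 224, 224, dtype=torch.float16, device="cuda")
#   out = model_fp16(x_fp16)  # -> LLVM ERROR: Broken function found
```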

pyjhzwh avatar Jul 20 '22 18:07 pyjhzwh