[LLVMGPU][ROCm] SDXL int8 fails to compile on gfx90a
Repro
Download the model from https://github.com/nod-ai/sdxl-scripts/blob/main/int8-model/base_ir/punet_07_18.mlir.
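For reference, the file can be fetched directly with curl; the raw URL below is an assumption based on GitHub's usual raw-file scheme for the blob link above.
# Assumed raw URL derived from the blob link above.
curl -LO https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/base_ir/punet_07_18.mlir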
Compilation command:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx90a punet_07_18.mlir -o punet.vmfb
Compile error:
<unknown>:0: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
<unknown>:0: note: see current operation: %83 = "builtin.unrealized_conversion_cast"(%82) : (!llvm.array<1 x array<4 x vector<2xi32>>>) -> vector<1x4x2xi32>
../punet_07_18.mlir:13702:13: error: failed to translate the MLIR LLVM dialect to the native llvm::Module
%4634 = torch.prims.convert_element_type %4633, %int5_805 : !torch.vtensor<[2,4096,640],f32>, !torch.int -> !torch.vtensor<[2,4096,640],f16>
^
This compiles fine for gfx942. I will try to minimize this next.
cc: @MaheshRavishankar
The failure is in the lowering of an (i8, i8) -> i32 batch matmul op. Here is a smaller repro IR:
module {
  func.func @punet_repro(%11 : tensor<2x4096x640xi8>, %12 : tensor<2x640x640xi8>, %13 : tensor<640xi32>, %14 : tensor<640xf32>) -> tensor<2x4096x640xi32> {
    %c0_i32 = arith.constant 0 : i32
    %16 = tensor.empty() : tensor<2x4096x640xi32>
    %17 = linalg.fill ins(%c0_i32 : i32) outs(%16 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
    %18 = linalg.batch_matmul_transpose_b ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%17 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
    return %18 : tensor<2x4096x640xi32>
  }
}
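Assuming the IR above is saved as repro.mlir (the filename is illustrative), it can be compiled with the same flags as the full model:
# Same target flags as the full-model repro above; repro.mlir and repro.vmfb are placeholder names.
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx90a repro.mlir -o repro.vmfb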
This is going down the SIMT pipeline, as shown in the dump here, and it leaves behind:
%157 = builtin.unrealized_conversion_cast %156 : vector<1x4x2xi32> to !llvm.array<1 x array<4 x vector<2xi32>>>
That cast then fails in the LLVM translation.
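For reference, a per-pass IR dump like the one linked above can be produced with MLIR's standard IR-printing flags; the exact flags and output path here are illustrative, not the commands used for the linked dump.
# Print IR after every pass to stderr so the leftover builtin.unrealized_conversion_cast
# can be located before LLVM translation; dump.txt is a placeholder output file.
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx90a \
  repro.mlir -o repro.vmfb --mlir-print-ir-after-all 2> dump.txt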
I have a WIP PR https://github.com/iree-org/iree/pull/18433 that seems to solve the issue (at least the batch matmul compiles). I will clean it up, add some e2e tests, and put it up for review later today.