torch-mlir
aten.pad with mode = reflect is lowered incorrectly (numeric mismatch)
Example IR
func.func @torch.aten.pad.reflect(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> {
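%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%float0.000000e00 = torch.constant.float 0.000000e+00
%1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int
%3 = torch.aten.select.int %pads, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
%pad = torch.prim.ListConstruct %2, %4 : (!torch.int, !torch.int) -> !torch.list<int>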
%str = torch.constant.str "reflect"
%ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.tensor<[4],f32>
return %ret : !torch.tensor<[4],f32>
}
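To make the mismatch concrete, here is a minimal pure-Python sketch of the 1-D semantics, assuming %pads holds [1, 1] (the only reflect padding that turns a [2] input into a [4] output, since each reflect pad amount must be smaller than the input size). The helpers reflect_pad_1d and constant_pad_1d are hypothetical and only model the expected result versus the result the current lowering produces.

```python
from typing import List, Sequence


def reflect_pad_1d(x: Sequence[float], left: int, right: int) -> List[float]:
    """Expected result of aten.pad(x, [left, right], mode="reflect") on 1-D data.

    The border element is mirrored but not repeated, and each pad amount
    must be smaller than len(x). Negative (cropping) pads are not modeled.
    """
    n = len(x)
    if not (0 <= left < n and 0 <= right < n):
        raise ValueError("reflect padding must be in [0, len(x))")
    out = []
    for i in range(-left, n + right):
        if i < 0:
            j = -i                  # mirror around index 0
        elif i >= n:
            j = 2 * (n - 1) - i     # mirror around index n - 1
        else:
            j = i                   # interior element, copied as-is
        out.append(x[j])
    return out


def constant_pad_1d(x: Sequence[float], left: int, right: int,
                    value: float = 0.0) -> List[float]:
    """Constant padding (aten.constant_pad_nd semantics) with fill `value`."""
    return [value] * left + list(x) + [value] * right


x = [1.0, 2.0]
print(reflect_pad_1d(x, 1, 1))   # [2.0, 1.0, 2.0, 1.0]  (expected)
print(constant_pad_1d(x, 1, 1))  # [0.0, 1.0, 2.0, 0.0]  (what constant padding yields)
```

Only the padded border positions differ; the interior values are unchanged, which is why this shows up as a numeric mismatch rather than a shape error.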
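This happens because the aten.pad decomposition never checks mode and unconditionally rewrites the op to a constant pad (aten.constant_pad_nd), so the padded positions receive the fill value instead of reflected input values: https://github.com/llvm/torch-mlir/blob/919b599ebe57b1402b1aca21fa54799cc1e0cd91/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L5864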
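The direction this points to is a mode-aware decomposition: only mode "constant" is equivalent to constant_pad_nd, and any other mode should not be matched (leaving the op for a pattern that handles that mode). The sketch below models that decision in plain Python; decompose_pad and the returned op-name strings are hypothetical stand-ins for a rewrite pattern's match or match-failure, not torch-mlir's C++ API.

```python
from typing import Optional

# Modes accepted by aten.pad; only "constant" matches aten.constant_pad_nd.
KNOWN_MODES = {"constant", "reflect", "replicate", "circular"}


def decompose_pad(mode: str, value: Optional[float]) -> Optional[str]:
    """Hypothetical model of a mode-aware aten.pad decomposition.

    Returns the name of the op to rewrite to, or None to signal
    "no match" so that the original aten.pad is left untouched.
    """
    if mode not in KNOWN_MODES:
        raise ValueError(f"unknown padding mode: {mode!r}")
    if mode != "constant":
        # The mode check the current pattern is missing.
        return None
    # For constant mode a missing fill value defaults to 0.0, as in PyTorch.
    fill = 0.0 if value is None else value
    return f"aten.constant_pad_nd(value={fill})"


print(decompose_pad("constant", None))  # aten.constant_pad_nd(value=0.0)
print(decompose_pad("reflect", 0.0))    # None
```

Returning "no match" instead of a wrong rewrite keeps the failure visible (the op stays unlowered) rather than silently producing wrong numbers.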
The bug is also reachable from ONNX models: an ONNX Pad op with mode = "reflect" triggers it.
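When coming from ONNX, the pads input uses a different layout than aten.pad: ONNX lists all begin amounts first and then all end amounts ([x1_begin, x2_begin, ..., x1_end, x2_end]), while aten.pad takes (begin, end) pairs starting from the last dimension. A minimal sketch of that reordering, assuming a fully static pads list; onnx_pads_to_torch is a hypothetical helper and not necessarily how torch-mlir's ONNX importer performs the conversion.

```python
from typing import List, Sequence


def onnx_pads_to_torch(pads: Sequence[int]) -> List[int]:
    """Reorder ONNX Pad 'pads' into the pair list that aten.pad expects.

    ONNX:      [d0_begin, d1_begin, ..., d0_end, d1_end, ...]
    aten.pad:  [last_begin, last_end, second_last_begin, second_last_end, ...]
    """
    if len(pads) % 2 != 0:
        raise ValueError("ONNX pads must have 2 * rank entries")
    rank = len(pads) // 2
    torch_pads: List[int] = []
    for dim in reversed(range(rank)):  # aten.pad starts from the last dimension
        torch_pads += [pads[dim], pads[dim + rank]]
    return torch_pads


print(onnx_pads_to_torch([1, 1]))        # [1, 1]  (1-D case, as for %pads above)
print(onnx_pads_to_torch([0, 2, 0, 3]))  # [2, 3, 0, 0]
```

For a 1-D input the two layouts coincide, so the reordering does not affect the example; the wrong values come from the mode being ignored, not from the pad amounts.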