torch-mlir
Support convolution with `valid` padding.
A convolution created with `valid` padding produces an `aten.convolution` op of the following form:
```mlir
module {
  func.func @main(%arg0: !torch.vtensor<[1,64,57],f32>) -> !torch.vtensor<[1,64,57],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %false = torch.constant.bool false
    %int1 = torch.constant.int 1
    %0 = torch.vtensor.literal(dense<0.536443591> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %1 = torch.vtensor.literal(dense<-7.486820e-03> : tensor<1x1x1x1xf32>) : !torch.vtensor<[1,1,1,1],f32>
    %int0 = torch.constant.int 0
    %2 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[1,64,57],f32>, !torch.int -> !torch.vtensor<[1,1,64,57],f32>
    %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %4 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %5 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %6 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %7 = torch.aten.convolution %2, %1, %0, %3, %4, %5, %false, %6, %int1 : !torch.vtensor<[1,1,64,57],f32>, !torch.vtensor<[1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,64,57],f32>
    %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1,1,64,57],f32>, !torch.int -> !torch.vtensor<[1,64,57],f32>
    return %8 : !torch.vtensor<[1,64,57],f32>
  }
}
```
Note that the `padding` operand of `aten.convolution` (`%4`) is a 1-element list (the `output_padding` operand, `%6`, is likewise 1-element), whereas the lowerings expect it to have one entry per spatial dimension of the input (two here). In the TorchToLinalg pass, this trips the assertion at https://github.com/sahas3/torch-mlir/blob/dc7a1ff7d9134758128a637dca976f72c2366e59/lib/Conversion/TorchToLinalg/Utils.cpp#L78. The lowerings to TOSA and StableHLO fail in different ways, but the failures stem from the same root cause.
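One possible direction is to broadcast a 1-element list to the number of spatial dimensions before lowering. The sketch below assumes that broadcasting is the intended semantics (as far as we can tell, this matches how PyTorch's eager convolution expands a single-element padding list). The helpers `expand_to_spatial_dims` and `conv_output_size` are hypothetical and are not part of torch-mlir. The sketch also checks that, with the broadcast padding, the standard output-size formula reproduces the `[64, 57]` spatial shape in the op's result type.

```python
from typing import List, Sequence


def expand_to_spatial_dims(values: Sequence[int], num_spatial_dims: int) -> List[int]:
    """Hypothetical helper: give a conv parameter list one entry per spatial dim.

    A 1-element list (as emitted for `valid` padding) is broadcast to every
    spatial dimension; a list that already has the right length is returned
    unchanged; any other length is rejected.
    """
    if len(values) == num_spatial_dims:
        return list(values)
    if len(values) == 1:
        return [values[0]] * num_spatial_dims
    raise ValueError(f"expected 1 or {num_spatial_dims} values, got {len(values)}")


def conv_output_size(in_size: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
    # Standard convolution output-size formula for a single spatial dimension.
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Shapes taken from the IR: input %2 is [1, 1, 64, 57], weight %1 is [1, 1, 1, 1].
input_shape = [1, 1, 64, 57]
kernel_spatial = [1, 1]
num_spatial = len(input_shape) - 2  # N and C are not spatial

padding = expand_to_spatial_dims([0], num_spatial)      # %4: 1-element list
stride = expand_to_spatial_dims([1, 1], num_spatial)    # %3: already 2 elements
dilation = expand_to_spatial_dims([1, 1], num_spatial)  # %5: already 2 elements

out_spatial = [
    conv_output_size(input_shape[2 + i], kernel_spatial[i], stride[i], padding[i], dilation[i])
    for i in range(num_spatial)
]
print(padding, out_spatial)  # [0, 0] [64, 57]
```

Whether such an expansion belongs in each lowering or in a shared canonicalization before them is a design choice; doing it once upstream would cover the TorchToLinalg, TOSA, and StableHLO paths together.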