torch-mlir

Support convolution with `valid` padding.

Open sahas3 opened this issue 1 year ago • 1 comment

A convolution created with `valid` padding produces an `aten.convolution` op of the following form:

```mlir
module {
  func.func @main(%arg0: !torch.vtensor<[1,64,57],f32>) -> !torch.vtensor<[1,64,57],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %false = torch.constant.bool false
    %int1 = torch.constant.int 1
    %0 = torch.vtensor.literal(dense<0.536443591> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %1 = torch.vtensor.literal(dense<-7.486820e-03> : tensor<1x1x1x1xf32>) : !torch.vtensor<[1,1,1,1],f32>
    %int0 = torch.constant.int 0
    %2 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[1,64,57],f32>, !torch.int -> !torch.vtensor<[1,1,64,57],f32>
    %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %4 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %5 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %6 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %7 = torch.aten.convolution %2, %1, %0, %3, %4, %5, %false, %6, %int1 : !torch.vtensor<[1,1,64,57],f32>, !torch.vtensor<[1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,64,57],f32>
    %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1,1,64,57],f32>, !torch.int -> !torch.vtensor<[1,64,57],f32>
    return %8 : !torch.vtensor<[1,64,57],f32>
  }
}
```
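For reference, the result type `[1,1,64,57]` is consistent with the usual convolution shape formula, out = floor((in + 2*pad - dilation*(k - 1) - 1) / stride) + 1, applied to each spatial dimension once the single padding value is used for both H and W. The minimal Python sketch below computes this and, like the lowerings, rejects a per-dimension list whose length does not match the spatial rank. `conv_output_spatial` is a hypothetical helper for illustration, not a torch-mlir function.

```python
# Sketch: spatial output sizes of a convolution using the standard formula
#   out = floor((in + 2*pad - dilation*(k - 1) - 1) / stride) + 1
# `conv_output_spatial` is a hypothetical helper, not a torch-mlir API.
# Like the lowerings, it requires one stride/padding/dilation entry per spatial dim.

def conv_output_spatial(in_sizes, kernel_sizes, stride, padding, dilation):
    """Return output spatial sizes; every per-dim list must match len(in_sizes)."""
    n = len(in_sizes)
    for name, param in (("stride", stride), ("padding", padding), ("dilation", dilation)):
        if len(param) != n:
            raise ValueError(f"{name} has {len(param)} entries, expected {n}")
    return [
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, s, p, d in zip(in_sizes, kernel_sizes, stride, padding, dilation)
    ]


# 1x1 kernel, stride 1, no padding: H and W are preserved.
print(conv_output_spatial([64, 57], [1, 1], [1, 1], [0, 0], [1, 1]))
```

Passing `padding=[0]` instead of `[0, 0]` raises `ValueError`, mirroring the length mismatch the lowerings run into.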

Note that the `padding` operand of `aten.convolution` (`%4` above) is a 1-element list, whereas the lowerings expect it to have one entry per spatial dimension of the input (two here, since the convolution input is `[1,1,64,57]`). In the TorchToLinalg pass this trips the assertion at https://github.com/sahas3/torch-mlir/blob/dc7a1ff7d9134758128a637dca976f72c2366e59/lib/Conversion/TorchToLinalg/Utils.cpp#L78. Lowering to TOSA and StableHLO fails in different ways, but both failures stem from the same root cause.
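One possible direction is to normalize the list parameters before any lowering indexes them per spatial dimension. The sketch below assumes a 1-element list should be read the way PyTorch reads a scalar `padding` (the same value for every spatial dimension). `normalize_conv_param` is a hypothetical name, and where such a normalization would live in torch-mlir is left open.

```python
# Sketch: expand a convolution list parameter to one entry per spatial dim.
# Assumption: a 1-element list means "use this value for every spatial dim",
# matching how PyTorch treats a scalar padding.
# `normalize_conv_param` is a hypothetical helper, not an existing torch-mlir function.

def normalize_conv_param(values, num_spatial_dims, name="param"):
    """Broadcast a 1-element list to num_spatial_dims entries; validate other lengths."""
    values = list(values)
    if len(values) == 1:
        return values * num_spatial_dims
    if len(values) != num_spatial_dims:
        raise ValueError(
            f"{name} has {len(values)} entries, expected 1 or {num_spatial_dims}"
        )
    return values


# The aten.convolution input is [N, C, H, W] = [1, 1, 64, 57]: rank 4, so 2 spatial dims.
input_rank = 4
num_spatial = input_rank - 2
padding = normalize_conv_param([0], num_spatial, "padding")                # [0, 0]
output_padding = normalize_conv_param([0], num_spatial, "output_padding")  # [0, 0]
```

Normalizing once, ahead of the TorchToLinalg, TOSA, and StableHLO conversions, would address the shared root cause in a single place rather than patching each lowering. The same treatment likely applies to `output_padding`, which is also the 1-element list `[0]` (`%6`) in the IR above.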

sahas3, Oct 18 '24 15:10