[Linalg] Bring back onnx AveragePool padding asymmetric support

Open AmosLewis opened this issue 1 year ago • 1 comments

Follow up of https://github.com/llvm/torch-mlir/pull/3235 by @zjgarvey

https://github.com/llvm/torch-mlir/commit/ae6f5e8251db09b03adc81fb4a9c0f1f4f87a7ae

AmosLewis, Jun 13 '24 03:06

Based on Zach's comment below, we are converting it to a draft, leaving this issue unresolved for now, and prioritizing the real-model-related issue requests before Jun 30.

I'm double checking the math and it's taking a bit longer than expected. I think the suggestion I made of where to change the code is not completely correct, and this might take a bit more work than expected to remedy.

The way it is currently set up, with ih0 = oh*dh - lowpadH:

  1. left endpoint without pad = max(oh*dh - lowpadH , 0) (correct, since this will shift by the number of included pad elements on the left side of the kernel window used in the oh computation).

  2. As far as I'm aware, ih1 is always equal to ih0 + kH since this would never be greater than Hin + highpadH. (checked using both onnx and pytorch definitions for Hout, and the fact that the largest oh is Hout - 1).

  3. Therefore, right endpoint without pad = min(oh*dh - lowpadH + kH, Hin), which is equivalent to min(oh*dh +kH, Hin + lowpadH) - lowpadH, but should actually be min(oh*dh + kH, Hin + highpadH) - lowpadH in order to correctly exclude the high padding values.
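
Point 2 above can be sanity-checked numerically. The following is a minimal sketch, assuming `dh` denotes the stride, dilation is 1, and `Hout` uses the floor-mode formula `floor((Hin + lowpadH + highpadH - kH) / dh) + 1`; the function names are hypothetical and not taken from torch-mlir.

```python
def floor_out_size(h_in, k, stride, low_pad, high_pad):
    """Floor-mode output length for one spatial dim (dilation 1 assumed)."""
    return (h_in + low_pad + high_pad - k) // stride + 1


def window_right_edge_fits(h_in, k, stride, low_pad, high_pad):
    """Check that ih0 + kH <= Hin + highpadH for every valid oh,
    with ih0 = oh*dh - lowpadH as in the current setup."""
    h_out = floor_out_size(h_in, k, stride, low_pad, high_pad)
    return all(
        oh * stride - low_pad + k <= h_in + high_pad
        for oh in range(h_out)
    )


if __name__ == "__main__":
    # An exaggeratedly asymmetric configuration (values chosen for illustration).
    print(window_right_edge_fits(h_in=5, k=3, stride=2, low_pad=0, high_pad=4))
```

Under these assumptions the check holds for every configuration, because the largest `oh` is `Hout - 1` and `(Hout - 1) * dh <= Hin + lowpadH + highpadH - kH`. A ceil-mode output size is not covered by this sketch.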

An alternative which would work better:

Use ih0 = oh*dh and ih1 = oh*dh + kH so that ih0,ih1 represent the left and right positions, in the padded input tensor, of the kernel window used for computing at oh.

Then:

  1. Compute left endpoint without pad = max(ih0, lowPad)
  2. Compute right endpoint without pad = min(ih1, Hin + highPad)

Note: these endpoints literally correspond to the kernel window as it sits inside the padded input tensor.
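
One way to validate any closed-form choice of endpoints is to compare it against a brute-force count of non-pad elements per window in the padded coordinate system. The sketch below is only a reference for testing, not the lowering itself; it assumes one spatial dimension, dilation 1, floor-mode output size, and that `dh` is the stride. The name `nonpad_counts` is hypothetical.

```python
def nonpad_counts(h_in, k, stride, low_pad, high_pad):
    """For each output position oh, count the non-pad elements inside the
    kernel window [oh*stride, oh*stride + k) of the padded input."""
    # 1 marks a real input element, 0 marks a pad element.
    mask = [0] * low_pad + [1] * h_in + [0] * high_pad
    h_out = (len(mask) - k) // stride + 1
    return [sum(mask[oh * stride: oh * stride + k]) for oh in range(h_out)]


if __name__ == "__main__":
    # Padding only on the high side, then only on the low side.
    print(nonpad_counts(h_in=4, k=3, stride=1, low_pad=0, high_pad=2))
    print(nonpad_counts(h_in=4, k=3, stride=1, low_pad=2, high_pad=0))
```

For a candidate pair of endpoints, `right - left` at each `oh` should equal the corresponding entry of this list, which is the divisor needed when countIncludePad is false.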

I wouldn't concern yourself with the countIncludePad==true case in this linalg generic since this is handled properly in the original code.


I'm not sure if this is very high priority right now or not. I have an additional concern about the fact that the result shape of the onnx op in the asymmetric case could possibly be inconsistent with the shape inference result of the pytorch op. Before investing too much additional time, it might be a good idea to add a lit test for an exaggeratedly asymmetric op and see.
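
To make the shape concern concrete, here is a rough sketch that compares the floor-mode output size with separate low/high pads against the size obtained from a single symmetric padding value. Both helpers are hypothetical and assume dilation 1 with no ceil mode; they say nothing about how the actual lowering chooses its padding.

```python
def asymmetric_out_size(h_in, k, stride, low_pad, high_pad):
    """Output length when low and high pads are given separately."""
    return (h_in + low_pad + high_pad - k) // stride + 1


def symmetric_out_size(h_in, k, stride, pad):
    """Output length when one padding value is applied to both sides."""
    return (h_in + 2 * pad - k) // stride + 1


if __name__ == "__main__":
    h_in, k, stride = 4, 3, 1
    print("asymmetric pads (0, 3):", asymmetric_out_size(h_in, k, stride, 0, 3))
    for pad in (1, 2):
        print(f"symmetric pad {pad}:", symmetric_out_size(h_in, k, stride, pad))
```

With stride 1 and an odd total padding, no single symmetric padding value reproduces the asymmetric size, which is the kind of case an exaggeratedly asymmetric lit test would exercise.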

AmosLewis, Jun 14 '24 01:06