Support max_unpool2d lowering
Support torch->linalg lowering for max_unpool2d.
Fixes https://github.com/nod-ai/SHARK-ModelDev/issues/718
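For reference, here is a minimal eager-mode snippet of the op being lowered (nothing here is specific to this PR): `max_unpool2d` is the partial inverse of `max_pool2d`, scattering the pooled values back to the positions recorded in `indices` and zero-filling everything else.

```python
import torch
import torch.nn.functional as F

# max_unpool2d scatters the pooled values back to the positions recorded
# in `indices` and fills the remaining output positions with zeros.
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
unpooled = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
print(unpooled.shape)  # torch.Size([1, 1, 4, 4])
```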
@vivekkhandelwal1 It got lowered to torch.aten.max_unpool3d, so it can be lowered to linalg from there. I added lit tests for both the onnx->torch and the torch->linalg lowerings. There is basically no big difference between 2D and 3D, so they could be generalized into one op, but I couldn't rename the op because of upstream PyTorch; I attached the related links in the commit message.
@samutamm
Do you mean to say that you're lowering the 4-d input case of Onnx.MaxUnpool to AtenMaxUnpool3d, which should instead have been lowered to AtenMaxUnpool2d, and that you're handling the 4-d input in the lowering of AtenMaxUnpool3d itself, instead of keeping it as a separate op?
@vivekkhandelwal1 Yes, 2D and 3D max_unpool can be generalized as one op.
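To illustrate the equivalence in eager PyTorch (a sketch only, not the actual lowering code): with a unit depth dimension, the flattened 2d indices over H*W coincide with the 3d indices over D*H*W, so `max_unpool3d` reproduces `max_unpool2d` exactly.

```python
import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

# Emulate max_unpool2d via max_unpool3d: unsqueeze to NCDHW with D=1.
# With a unit depth, indices flattened over H*W and over D*H*W coincide,
# so the 2d indices can be reused as-is.
un3d = F.max_unpool3d(
    pooled.unsqueeze(2), indices.unsqueeze(2),
    kernel_size=(1, 2, 2), stride=(1, 2, 2),
).squeeze(2)

un2d = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
assert torch.equal(un3d, un2d)
```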
That's fine, but what you've done in this PR is not correct. You have added support for the 2d pooling case in the MaxUnpool3d op, which is wrong. Ideally, you should have added the lowering for the MaxUnpool2d op. If there is an issue in PyTorch that prevents that, then you can define a new op in https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/TorchOps.td, say TorchAtenMaxUnpoolOp, decompose both the Unpool3d and Unpool2d ops to this particular op, and add the torch->linalg lowering for it. (Before doing this, we have to be sure what exactly the issue with using the TorchMaxUnpool2d op is, and whether it can be fixed in upstream PyTorch.)
@vivekkhandelwal1 For sure I could add a separate lowering for 2D, but that would be mostly duplicated code. Is that okay?
No, you should not do it in a way that duplicates the code. Instead, move the common code into a utility or templatize the lowering so that it can be re-used.
@vivekkhandelwal1 Using the 3D lowering is also because torch.aten.max_unpool2d is missing the `pads` and `strides` inputs, as mentioned here: https://github.com/nod-ai/SHARK-ModelDev/issues/764#issuecomment-2258978758. I wonder why we don't pass more info through the torch op, even `kernel_shape`, so that we wouldn't need to compute the kernel size here: https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchToLinalg/Pooling.cpp#L664. Do you have any suggestions on how to lower the 2D case without `pads` and `strides`?
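The asymmetry is visible in the ATen schemas themselves; for example (the printed schema strings may differ slightly across PyTorch versions):

```python
import torch

# max_unpool2d carries only output_size, while max_unpool3d also carries
# stride and padding (printed schemas are approximate and version-dependent).
print(torch.ops.aten.max_unpool2d.default._schema)
# aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
print(torch.ops.aten.max_unpool3d.default._schema)
# aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size,
#                    int[3] stride, int[3] padding) -> Tensor
```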
I think this is actually an issue with the PyTorch definition of the max_unpool2d op. The possible ways to fix this are either to make the fix in upstream PyTorch or to define an op in TorchOps.td and use that for the lowering.
This needs https://github.com/pytorch/pytorch/pull/138805 to be merged first.