torch-mlir

Support max_unpool2d lowering

Open · jinchen62 opened this issue 1 year ago • 10 comments

Support torch->linalg lowering for max_unpool2d.

Fixes https://github.com/nod-ai/SHARK-ModelDev/issues/718

jinchen62 · Sep 25 '24 14:09

@vivekkhandelwal1 The 2D case got lowered to torch.aten.max_unpool3d, so it could then be lowered to linalg. I added lit tests for both onnx->torch and torch->linalg. Basically there is no big difference between 2D and 3D, so the lowering can be generalized, but I couldn't rename the op because it comes from upstream PyTorch; I attached the related links in the commit message.

jinchen62 · Sep 30 '24 08:09
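To illustrate why the 2D and 3D cases generalize, here is a minimal pure-Python reference of max_unpool semantics. It is not the torch-mlir lowering; the function name and the layout (one flattened spatial plane per batch/channel pair, with flat indices as returned by max pooling with `return_indices=True`) are assumptions for illustration. The only thing that depends on the spatial rank is the size of each output plane.

```python
from math import prod
from typing import List, Sequence


def max_unpool(values: Sequence[Sequence[float]],
               indices: Sequence[Sequence[int]],
               output_spatial: Sequence[int]) -> List[List[float]]:
    """Hypothetical reference max_unpool for any number of spatial dims.

    values/indices: one flat list per (batch, channel) plane.
    output_spatial: (H, W) for the 2D case or (D, H, W) for the 3D case.
    """
    plane_size = prod(output_spatial)  # the only rank-dependent quantity
    out = []
    for vals, idxs in zip(values, indices):
        plane = [0.0] * plane_size  # positions not chosen by max pooling stay zero
        for v, i in zip(vals, idxs):
            if not 0 <= i < plane_size:
                raise IndexError(f"index {i} out of range for plane of size {plane_size}")
            plane[i] = v  # scatter the pooled value back to its flat position
        out.append(plane)
    return out
```

For example, `max_unpool([[5.0]], [[3]], (2, 2))` and `max_unpool([[5.0]], [[3]], (1, 2, 2))` both produce `[[0.0, 0.0, 0.0, 5.0]]`: the scatter itself does not care whether the plane is 2D or 3D.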

@samutamm

zjgarvey · Oct 02 '24 13:10

Do you mean that you're lowering the 4-D input case of Onnx.Unpool to AtenMaxUnpool3d, when it should instead have been lowered to AtenMaxUnpool2d, and handling the 4-D input in the AtenMaxUnpool3d lowering itself instead of having it as a separate op?

vivekkhandelwal1 · Oct 04 '24 10:10
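One way a 3D lowering could absorb the 4-D input case is to view an H×W plane as a 1×H×W volume. This is an assumption for illustration; the thread does not show exactly how the PR does it. The sketch below checks that flat indices produced by 2D max pooling stay valid under that view, and shows the hypothetical helper `lift_unpool2d_to_3d`, which pads the shape, output size, stride, and padding with a unit depth dimension. The stride and padding inputs are assumed to be one value per spatial dim (e.g. already derived from the ONNX op's attributes).

```python
from typing import Sequence, Tuple


def flat_index(coords: Sequence[int], dims: Sequence[int]) -> int:
    """Row-major flat offset of `coords` inside a spatial block of extent `dims`."""
    idx = 0
    for c, d in zip(coords, dims):
        idx = idx * d + c
    return idx


def lift_unpool2d_to_3d(input_shape: Sequence[int],
                        output_size: Sequence[int],
                        stride: Sequence[int],
                        padding: Sequence[int]) -> Tuple[tuple, tuple, tuple, tuple]:
    """Hypothetical helper: describe a 4-D (N, C, H, W) unpool as a 5-D one with depth 1.

    Depth gets stride 1 and padding 0, so with a depth kernel of 1 the output
    depth (1 - 1) * 1 - 2 * 0 + 1 stays 1.
    """
    n, c, h, w = input_shape
    return ((n, c, 1, h, w), (1, *output_size), (1, *stride), (0, *padding))


# An index into a 3x4 plane is unchanged when the plane is viewed as a 1x3x4 volume.
assert flat_index((1, 2), (3, 4)) == flat_index((0, 1, 2), (1, 3, 4)) == 6
```

Because a leading unit dimension contributes nothing to the row-major offset, the indices tensor can be reused as-is; only the shapes and attribute lists change.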

@vivekkhandelwal1 Yes, 2D and 3D max_unpool can be generalized as one op.

jinchen62 · Oct 04 '24 15:10

That's fine, but what you've done in this PR is not correct: you have added support for the 2D pooling case to the MaxUnpool3d op, which is wrong. Ideally, you should add a lowering for the MaxUnpool2d op. If there is a PyTorch-related issue preventing that, you can define a new op in https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/TorchOps.td, say TorchAtenMaxUnpoolOp, decompose both the 3D and 2D unpool ops to it, and add the torch->linalg lowering for that op. Before doing so, though, we need to be sure what exactly the issue with using the TorchMaxUnpool2d op is, and whether it can be fixed in upstream PyTorch.

vivekkhandelwal1 · Oct 04 '24 15:10

@vivekkhandelwal1 Sure, I could add a separate lowering for 2D, but it would be mostly duplicate code. Is it okay?

jinchen62 · Oct 04 '24 16:10

No, you should not duplicate the code. Instead, move the common code into a utility or templatize the lowering so that it can be reused.

vivekkhandelwal1 · Oct 04 '24 16:10
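The "shared utility" approach can be sketched as one rank-generic function with thin per-op entry points that differ only in the spatial rank. In the C++ lowering this would be a helper or a template parameterized on the rank; the Python below is only an analogy, and every name in it is hypothetical. It handles the shape checks and result shape only, and assumes batched inputs: (N, C, H, W) for 2D and (N, C, D, H, W) for 3D.

```python
from typing import Sequence, Tuple


def infer_unpool_result_shape(self_shape: Sequence[int],
                              indices_shape: Sequence[int],
                              output_size: Sequence[int],
                              spatial_rank: int) -> Tuple[int, ...]:
    """Hypothetical shared helper: the rank-independent part of an unpool lowering."""
    if len(self_shape) != spatial_rank + 2:
        raise ValueError(f"expected rank {spatial_rank + 2} input, got {len(self_shape)}")
    if tuple(indices_shape) != tuple(self_shape):
        raise ValueError("indices must have the same shape as the input")
    if len(output_size) != spatial_rank:
        raise ValueError(f"expected {spatial_rank} output sizes, got {len(output_size)}")
    n, c = self_shape[0], self_shape[1]
    return (n, c, *output_size)


# Thin per-op entry points: only the spatial rank differs, so no logic is duplicated.
def max_unpool2d_result_shape(self_shape, indices_shape, output_size):
    return infer_unpool_result_shape(self_shape, indices_shape, output_size, spatial_rank=2)


def max_unpool3d_result_shape(self_shape, indices_shape, output_size):
    return infer_unpool_result_shape(self_shape, indices_shape, output_size, spatial_rank=3)
```

For instance, `max_unpool2d_result_shape((1, 3, 2, 2), (1, 3, 2, 2), (4, 4))` returns `(1, 3, 4, 4)`, and passing a 4-D input to the 3D entry point raises `ValueError`.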

@vivekkhandelwal1 I also used the 3D lowering because torch.aten.max_unpool2d lacks pads and strides inputs, as mentioned here: https://github.com/nod-ai/SHARK-ModelDev/issues/764#issuecomment-2258978758. I wonder why we don't pass more information through the torch op, even kernel_shape, so that we wouldn't need to calculate the kernel size here: https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchToLinalg/Pooling.cpp#L664. Do you have any suggestions on how to lower the 2D case without pads and strides?

jinchen62 · Oct 08 '24 01:10
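Why pads and strides matter can be shown with the per-dimension output-size relation documented for PyTorch's MaxUnpool, out = (in - 1) * stride - 2 * pad + kernel. Assuming the lowering recovers the kernel size by solving that relation (we have not checked that this is exactly what Pooling.cpp does), the kernel is only determined once stride and pad are known. The sketch shows several (stride, pad) guesses that are all consistent with the same input and output sizes.

```python
def unpool_output_size(in_size: int, kernel: int, stride: int, pad: int) -> int:
    # Per-dimension relation from the PyTorch MaxUnpool docs:
    # out = (in - 1) * stride - 2 * pad + kernel
    return (in_size - 1) * stride - 2 * pad + kernel


def infer_kernel_size(in_size: int, out_size: int, stride: int, pad: int) -> int:
    # The same relation solved for kernel; this requires stride and pad.
    return out_size - (in_size - 1) * stride + 2 * pad


# With only in_size=2 and out_size=4 known (sizes are all max_unpool2d carries),
# different (stride, pad) guesses give different, equally consistent kernels.
for stride, pad in [(2, 0), (1, 0), (2, 1)]:
    k = infer_kernel_size(2, 4, stride, pad)
    assert unpool_output_size(2, k, stride, pad) == 4
    print(f"stride={stride}, pad={pad} -> kernel={k}")
```

This prints kernels 2, 3, and 4 respectively, which is why a 2D op without stride and padding operands cannot pin down the pooling window on its own.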

I think this is actually an issue with the PyTorch definition of the max_unpool2d op. There are two possible fixes: either fix it in upstream PyTorch, or define an op in TorchOps.td and use that for the lowering.

vivekkhandelwal1 · Oct 09 '24 10:10

This needs https://github.com/pytorch/pytorch/pull/138805 to be merged.

jinchen62 · Oct 24 '24 20:10