torch-mlir
[RFC] Grouped 2D convolution op in upstream
Just wanted to get some feedback on implementing a linalg op upstream that implements logic similar to #858, i.e. the same logic as the normal conv_2d_nchw_fchw op, but splitting the input and weight tensors per group beforehand and then concatenating the results. Is there a more efficient way to implement it at that low a level?
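For concreteness, here's a minimal pure-Python sketch of that decomposition (all function names here are invented for illustration, not from torch-mlir): a plain conv_2d_nchw_fchw kernel, then a grouped version that slices the input and weight channels per group and concatenates the per-group results along the output-channel dimension. Stride 1, no padding, no bias assumed.

```python
def conv2d_nchw_fchw(x, w):
    """Plain conv_2d_nchw_fchw on nested lists: x is [N][C][H][W], w is [F][C][KH][KW]."""
    n, c, h, wd = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    f, kh, kw = len(w), len(w[0][0]), len(w[0][0][0])
    oh, ow = h - kh + 1, wd - kw + 1  # stride 1, no padding
    out = [[[[0.0] * ow for _ in range(oh)] for _ in range(f)] for _ in range(n)]
    for b in range(n):
        for of in range(f):
            for oy in range(oh):
                for ox in range(ow):
                    acc = 0.0
                    for ic in range(c):
                        for ky in range(kh):
                            for kx in range(kw):
                                acc += x[b][ic][oy + ky][ox + kx] * w[of][ic][ky][kx]
                    out[b][of][oy][ox] = acc
    return out

def grouped_conv2d(x, w, groups):
    """Split channels into `groups` slices, run a regular conv per slice, concatenate."""
    c_per_g = len(x[0]) // groups  # input channels per group
    f_per_g = len(w) // groups     # output channels per group
    out = None
    for g in range(groups):
        xg = [batch[g * c_per_g:(g + 1) * c_per_g] for batch in x]
        wg = w[g * f_per_g:(g + 1) * f_per_g]
        og = conv2d_nchw_fchw(xg, wg)
        if out is None:
            out = og
        else:
            # Concatenate this group's results along the output-channel dim.
            for b in range(len(out)):
                out[b].extend(og[b])
    return out
```

With groups=1 this degenerates to the regular convolution, which is the equivalence the RFC relies on.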
The way to approach this would be to write the grouped convolution as a perfectly nested loop nest. From there it should be trivial to write the Linalg op definition. I had looked at this a while ago, but will have to get that back into cache. If you know the code directly, that would help.
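A perfectly nested loop nest for the grouped case might look like the following Python sketch (hypothetical names, not from any codebase). It assumes the tensors have already been expanded to [N][G][C/G][H][W] input and [G][F/G][C/G][KH][KW] weights, stride 1, no padding; every statement sits in the innermost loop, which is the shape a Linalg op definition wants.

```python
def grouped_conv2d_loopnest(x, w):
    """Grouped conv as one perfectly nested loop nest.

    x: [N][G][CpG][H][W], w: [G][FpG][CpG][KH][KW] -> out: [N][G][FpG][OH][OW]
    """
    n, g, cpg = len(x), len(x[0]), len(x[0][0])
    h, wd = len(x[0][0][0]), len(x[0][0][0][0])
    fpg, kh, kw = len(w[0]), len(w[0][0][0]), len(w[0][0][0][0])
    oh, ow = h - kh + 1, wd - kw + 1  # stride 1, no padding
    out = [[[[[0.0] * ow for _ in range(oh)]
             for _ in range(fpg)] for _ in range(g)] for _ in range(n)]
    for b in range(n):          # batch
        for gi in range(g):     # group ("hidden batch" dimension)
            for f in range(fpg):        # output channel within group
                for oy in range(oh):    # output rows
                    for ox in range(ow):        # output cols
                        for c in range(cpg):    # input channel within group
                            for ky in range(kh):        # kernel rows
                                for kx in range(kw):    # kernel cols
                                    out[b][gi][f][oy][ox] += (
                                        x[b][gi][c][oy + ky][ox + kx]
                                        * w[gi][f][c][ky][kx])
    return out
```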
I think it is literally just a batch of regular convolutions. There is a wrinkle: the high level op from frontends typically has the channel dimension be a multiple of the number of groups, so we first have to expand the channel dimension into two dimensions -- the actual number of channels, and the number of convolutions being done.
But structurally at the linalg level, I think we just need to add an op just like conv_2d_nchw_fchw but with an extra batch dimension.
> I think it is literally just a batch of regular convolutions. There is a wrinkle: the high level op from frontends typically has the channel dimension be a multiple of the number of groups, so we first have to expand the channel dimension into two dimensions -- the actual number of channels, and the number of convolutions being done.
This might be an issue, though, if the channel dimension is dynamic and the number of channels and convolutions are dynamic. tensor.expand_shape cannot expand a dynamic dim into multiple dynamic dims. It's something that we have known is needed, but have been kicking the can down the road.
The PyTorch op doesn't require the extra batch dimension, right? It should just be a matter of adding one or two extra nested for loops to the logic, i.e.
```python
for outBatch in range(inputDims[0]):
    inputOffset = 0
    for weightOffset in range(0, weightChannels, weightStride):
        for outChannel in range(weightOffset, weightOffset + weightStride):
            for inChannel in range(inputOffset, inputOffset + inputStride):
                ...
        inputOffset += inputStride
```
The number of groups is kind of a "hidden batch dimension". At the linalg level it cannot be hidden like that: we need to expand C -> [numGroups, CPerGroup].
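When the sizes divide evenly, that C -> [numGroups, CPerGroup] expansion is just a reshape of the channel dimension. A tiny Python sketch on nested lists (hypothetical helper name):

```python
def expand_channels(x, groups):
    """Reshape from [N][C][...] to [N][G][C/G][...].

    Assumes C is divisible by `groups`; each channel entry is kept as-is.
    """
    c = len(x[0])
    assert c % groups == 0, "channel count must be a multiple of the group count"
    cpg = c // groups
    return [[batch[g * cpg:(g + 1) * cpg] for g in range(groups)]
            for batch in x]
```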
Hmm. In that case could we just handle the case where the group count is constant? That'd mean going from ? -> [G, ?], and it's how we're already handling it in torch-mlir anyway.
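As a rough sketch of what that ? -> [G, ?] expansion could look like in IR, with a constant group count of 4 so that only one of the two expanded dims is dynamic (exact tensor.expand_shape syntax varies across MLIR versions; newer releases also require an output_shape operand):

```mlir
// Expand the dynamic channel dim of an NCHW tensor into [G, C/G], G = 4 constant.
%expanded = tensor.expand_shape %input [[0], [1, 2], [3], [4]]
    : tensor<?x?x?x?xf32> into tensor<?x4x?x?x?xf32>
```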
So if I'm understanding correctly, the named op should look something like this, and the main issue is that actually expanding the NCHW/FCHW tensors to fit into the new op isn't something we can currently do.
Yes, I think we would handle a constant group count.
But we still want to special-case the depthwise and regular convolution cases first, then add support for general group counts as a follow-on.
+1 to having a constant group count. IIUC you might need reshapes to expand dimensions, and tensor.expand_shape does not allow you to expand a single dynamic dimension into multiple dynamic dimensions. That might be a blocker, so a constant group count is probably what is easiest to support right now.
This is fixed after #858