Transposed Convolution
Proposed changes
This implementation extends `conv_general` and introduces a bool flag `transpose` so it can reuse the existing code path.
- Includes tests and benchmarks.
- Only supports `groups=1` for now.
- Does not yet support output padding.
- torch uses an `(in_channels, kernel_size, ..., out_channels)` weight shape for `conv_transpose`; this implementation uses `(out_channels, kernel_size, ..., in_channels)`. This is open for discussion.
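For intuition on why the existing convolution code path can be reused, here is a minimal 1D NumPy sketch (not the MLX code itself): a transposed convolution is equivalent to a regular convolution over a zero-inserted (stride-upsampled) input with a flipped kernel. The function names are illustrative only.

```python
import numpy as np

def conv1d(x, w):
    # Plain "valid" cross-correlation: output length is len(x) - len(w) + 1.
    n = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) for i in range(n)])

def conv_transpose1d(x, w, stride=1):
    # Insert (stride - 1) zeros between input samples, pad by len(w) - 1,
    # then run an ordinary convolution with the flipped kernel.
    k = len(w)
    up = np.zeros((len(x) - 1) * stride + 1)
    up[::stride] = x
    padded = np.pad(up, k - 1)
    return conv1d(padded, w[::-1])

x = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 0.5])
y = conv_transpose1d(x, w, stride=2)
# Same result as scattering x[i] * w into y[i * stride + j].
```

The output length matches the usual transposed-conv formula `(len(x) - 1) * stride + len(w)`, which is why a transpose flag on the existing primitive (rather than a separate one) is enough.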
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
@mlaves this looks really nice!! I think the main question is: what is the purpose of the `ConvolutionTranspose` primitive? I didn't look too carefully, but it looks like we should just use the `Convolution` primitive and send it the right arguments.
Other than that I'll review the rest of the API in some detail shortly but looks great so far.
@mlaves are you planning to come back to this? Would be great to land it.
Sorry for the delay, I'll finish that in the next couple of days.
@awni I added a `transpose` flag to the existing `Convolution` primitive and removed the initially introduced primitive, which is now obsolete. I also rebased onto main.
I removed the `transpose` flag and attempted to resolve the incorrect gradient w.r.t. the `weight` when the `flip` argument is `True`. I don't think all the gradients are correct yet / our tests are not comprehensive, but this is a good improvement over where we were.
We should add more comprehensive tests for backward, particularly with dilations / flip / groups. But we can leave that for a follow-on PR (along with any potential fixes).
Cool, that looks like a nice simplification over my initial implementation. This PR looks good to me now. I can follow up with more tests as mentioned.