escnn
Generalize the Fourier transform API
This PR is a proposal to refactor the Fourier transform API, with the goal of making it easier to incorporate Fourier transforms into other modules. Here are the two specific use-cases I was trying to facilitate:
- A Fourier max pooling layer, as discussed in #65. This would be very similar to the existing FourierPointwise class, except that after the nonlinearity there would also be max-pooling and Gaussian blurring steps.
- An IFT as an output layer. It's probably not clear what I mean by that, and it's possible that I only needed such a layer because I overlooked some easier way of doing things, so I want to take some time to explain the problem I was trying to solve. My goal was to reimplement [Doersch2016], but in 3D and with equivariance. The idea in [Doersch2016] is to create a self-supervised training protocol by taking two nearby crops of an image and having the model predict the location of the second relative to the first. There are only a handful of possible relative locations, e.g. above, below, right, and left (for 2D images). I implemented this by making the final layer of my model a single spectral regular representation (of the quotient space $S^2 = SO(3) / SO(2)$, because the two crops cannot rotate relative to each other), then performing an IFT with each grid point corresponding to one of the possible relative locations. This results in a value for each location that can be interpreted as a logit, and if the input rotates, so do the logits. To bring this back to the PR at hand, the important point is that this application requires being able to perform an IFT without a subsequent FT.
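A minimal sketch of the output-layer idea, using the 2D analogue rather than the 3D setup: with SO(2) in place of SO(3), the possible directions form the circle $S^1$, and the four relative locations right, up, left and down sit at angles 0, π/2, π and 3π/2. Everything here is a hypothetical stand-in, not escnn code: the spectral representation is just a list of cosine/sine coefficients, and ift_on_grid, rotate and GRID are illustrative names. It only uses the standard library, so the equivariance claim is easy to check: rotating the coefficients by 90° cyclically shifts the four logits.

```python
import math
from typing import List, Tuple

# Hypothetical 2D stand-in: a band-limited signal on the circle, stored as
# a0 plus one (a_k, b_k) cosine/sine pair per frequency k = 1..L.
Coeffs = Tuple[float, List[Tuple[float, float]]]

# One grid point per relative location: right, up, left, down.
GRID = [j * math.pi / 2 for j in range(4)]


def ift_on_grid(coeffs: Coeffs, grid: List[float]) -> List[float]:
    """Inverse Fourier transform: evaluate the signal at each grid angle."""
    a0, bands = coeffs
    logits = []
    for theta in grid:
        value = a0
        for k, (a, b) in enumerate(bands, start=1):
            value += a * math.cos(k * theta) + b * math.sin(k * theta)
        logits.append(value)
    return logits


def rotate(coeffs: Coeffs, phi: float) -> Coeffs:
    """Transform the coefficients so the signal f(theta) becomes f(theta - phi)."""
    a0, bands = coeffs
    rotated = []
    for k, (a, b) in enumerate(bands, start=1):
        c, s = math.cos(k * phi), math.sin(k * phi)
        rotated.append((a * c - b * s, a * s + b * c))
    return a0, rotated


coeffs: Coeffs = (0.1, [(0.8, -0.3), (0.5, 0.2)])
logits = ift_on_grid(coeffs, GRID)
rotated_logits = ift_on_grid(rotate(coeffs, math.pi / 2), GRID)

# A 90-degree rotation moves each logit one slot along the grid:
# the old "right" score becomes the new "up" score, and so on.
for j in range(4):
    assert math.isclose(rotated_logits[j], logits[(j - 1) % 4], abs_tol=1e-9)
```

In the actual model the coefficients would come from the network's final spectral regular field and the grid would be points on $S^2$; a softmax over the grid values would then, for example, give probabilities for the relative locations.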
I think that the best way to support these two use-cases, and possibly others that I haven't thought of, is to create separate FT and IFT modules. That's what the proposed API does. Here are the specific classes involved:
- InverseFourierTransform: A PyTorch module where the input is a tensor with a spectral regular representation, and the output is a tensor of signal values sampled on a grid.
- FourierTransform: The opposite of InverseFourierTransform. This module also provides the option to prepare the FT matrix with more irreps than will ultimately be output.
- FourierFieldType: Most equivariant modules accept input/output field types as arguments, but FourierPointwise is an exception. It accepts gspace, channels, and irreps arguments, and uses them to create a compatible field type under the hood. This API is a bit awkward to begin with, but it's worse when the same arguments need to be passed to two different modules.

  To bring the Fourier API in line with all the other modules, I created FourierFieldType. This is a subclass of FieldType that only allows spectral regular representations (possibly with respect to a quotient space). The IFT and FT modules require this field type (and check for it). Other modules are agnostic to it.
- GridTensor: A class that wraps the output of an IFT and the input to an FT. It's similar in concept to GeometricTensor, except that instead of keeping track of the representation associated with a tensor, it keeps track of the grid. This lets the FT module check that it's compatible with the input it receives, and (for GNNs) restore the coords attribute.
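To make the division of labor concrete, here is a toy, standard-library-only sketch of the same structure on the circle. The class names mirror the proposed ones, but everything here is a hypothetical stand-in: coefficients are plain lists [a0, a1, b1, ..., aL, bL] instead of a FourierFieldType, and this FourierTransform uses quadrature on a uniform grid, which is exact for band-limited inputs but is not necessarily how escnn builds its FT matrix. The point is the GridTensor pattern: the IFT records which grid it sampled, so the FT can reject input that came from a different grid.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class GridTensor:
    """Toy stand-in: signal samples plus the grid they were sampled on."""
    values: Tuple[float, ...]
    grid: Tuple[float, ...]


def uniform_grid(n: int) -> Tuple[float, ...]:
    """n equally spaced angles on the circle."""
    return tuple(2 * math.pi * j / n for j in range(n))


class InverseFourierTransform:
    """Coefficients [a0, a1, b1, ..., aL, bL] -> samples on a grid."""

    def __init__(self, max_freq: int, grid: Sequence[float]):
        self.max_freq = max_freq
        self.grid = tuple(grid)

    def __call__(self, coeffs: Sequence[float]) -> GridTensor:
        assert len(coeffs) == 2 * self.max_freq + 1
        values = []
        for theta in self.grid:
            v = coeffs[0]
            for k in range(1, self.max_freq + 1):
                v += coeffs[2 * k - 1] * math.cos(k * theta)
                v += coeffs[2 * k] * math.sin(k * theta)
            values.append(v)
        return GridTensor(tuple(values), self.grid)


class FourierTransform:
    """Samples on a uniform grid -> coefficients, via discrete quadrature."""

    def __init__(self, grid: Sequence[float], max_freq: int):
        # Assumes a uniform grid; exact for band-limited input when 2 * max_freq < n.
        if len(grid) <= 2 * max_freq:
            raise ValueError("grid has too few points for this max_freq")
        self.grid = tuple(grid)
        self.max_freq = max_freq

    def __call__(self, x: GridTensor) -> List[float]:
        # The GridTensor says which grid produced it, so mismatches are caught here.
        if x.grid != self.grid:
            raise ValueError("input was sampled on a different grid")
        n = len(self.grid)
        pairs = list(zip(x.values, self.grid))
        coeffs = [sum(x.values) / n]
        for k in range(1, self.max_freq + 1):
            coeffs.append(2 / n * sum(v * math.cos(k * t) for v, t in pairs))
            coeffs.append(2 / n * sum(v * math.sin(k * t) for v, t in pairs))
        return coeffs


grid = uniform_grid(8)
ift = InverseFourierTransform(max_freq=2, grid=grid)
ft = FourierTransform(grid, max_freq=2)

coeffs = [0.1, 0.8, -0.3, 0.5, 0.2]
x = ift(coeffs)
recovered = ft(x)  # equal to coeffs up to floating-point error

# Same composition as FourierRelu: IFT, pointwise ReLU on the grid, then FT.
activated = ft(GridTensor(tuple(max(v, 0.0) for v in x.values), x.grid))
```

Because the IFT and FT are separate objects that share a grid only through the GridTensor, an output layer can stop after the IFT, and a pooling layer could insert extra steps between the two.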
Using these classes, I reimplemented the FourierPointwise class in a way that I believe to be 100% backwards-compatible. The new implementation also removes hundreds of lines of code that were duplicated between FourierPointwise and QuotientFourierPointwise. Below is a simplified version of this class, FourierRelu, just to give a sense of how it works:
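```python
from typing import List, Tuple

import torch.nn.functional as F
from escnn.group import GroupElement
from escnn.nn import EquivariantModule, GeometricTensor

# FourierFieldType, GridTensor, InverseFourierTransform and FourierTransform
# are the new classes proposed in this PR.

class FourierRelu(EquivariantModule):

    def __init__(self, in_type: FourierFieldType, grid: List[GroupElement]):
        super().__init__()
        self.in_type = self.out_type = in_type
        self.ift = InverseFourierTransform(self.in_type, grid)
        self.ft = FourierTransform(grid, self.out_type)

    def forward(self, x_hat: GeometricTensor) -> GeometricTensor:
        assert x_hat.type == self.in_type
        x: GridTensor = self.ift(x_hat)
        F.relu_(x.tensor)  # pointwise nonlinearity on the grid samples
        return self.ft(x)

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        # Required by EquivariantModule; the field type is unchanged.
        return input_shape
```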
Minor comments:
- This PR isn't ready to be merged yet. I haven't updated the documentation, and although all the existing tests pass, I want to write some new tests as well. But before I spend a lot of time on those tasks, I want to know if there's any interest in merging this.
- I haven't implemented the aforementioned Fourier max pooling module yet. But if there's interest, I could add that to the PR as well.