Implement missing stablehlo.fft operations
See issue #22695. This change does the following:
- Causes a compile error when the FFT length has more than one dimension (this was never supported, but previously it silently gave the wrong answer).
- Fixes an issue with RFFT where the imaginary part (a constant zero tensor) could cause a large stack allocation and fail to compile. Fixed by creating the zero imaginary tensor as part of the input shuffling step.
- Implements the remaining FFT operations: FFT, IFFT, IRFFT.
- Changes the function name `rewriteFft` to `rewriteRfft`, since it only handles RFFT. The function `rewriteFft` is now used for the full FFT. Also adds `rewriteIrfft`.
- Updates Torch to use the new interface (i.e. it now calls the renamed `rewriteRfft`).
- Updates the Torch tests to reflect changes in RFFT code gen.
- Updates the StableHLO tests to reflect changes in RFFT code gen.
- Adds StableHLO tests for FFT, IFFT, and IRFFT.
- Adds e2e tests for `stablehlo.fft`. Note that these are disabled in VMVX due to apparent issues with complex numbers; I mention this in the issue. They are also disabled in Vulkan (again mentioned in the issue), but it's worth noting that the FFTs do compile and run when using `sm_75`, and fail to compile when the Vulkan target is omitted. I haven't tried other Vulkan targets.
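The first bullet above (rejecting an FFT length with more than one dimension), together with the input/output sizes of the four StableHLO FFT types, amounts to a shape check along these lines. This is a minimal Python sketch, not the code in this PR: `fft_output_length` is a hypothetical name, only the innermost dimension is checked, and the power-of-two requirement is an assumption based on the radix-2 approach used by the rewrite.

```python
# Hypothetical shape check; the name and signature are illustrative only.
def fft_output_length(fft_type: str, fft_length: list[int], input_length: int) -> int:
    """Return the size of the innermost output dimension, or raise ValueError."""
    # Multi-dimensional FFTs are rejected up front instead of silently
    # producing a wrong result.
    if len(fft_length) != 1:
        raise ValueError(f"unsupported fft_length rank {len(fft_length)}; expected 1")
    n = fft_length[0]
    # Assumption: a radix-2 lowering needs a power-of-two length.
    if n < 1 or n & (n - 1):
        raise ValueError(f"fft_length {n} is not a power of 2")
    if fft_type in ("FFT", "IFFT"):
        expected_input, output = n, n  # complex -> complex, same size
    elif fft_type == "RFFT":
        expected_input, output = n, n // 2 + 1  # real N -> complex N/2+1
    elif fft_type == "IRFFT":
        expected_input, output = n // 2 + 1, n  # complex N/2+1 -> real N
    else:
        raise ValueError(f"unknown fft_type {fft_type!r}")
    if input_length != expected_input:
        raise ValueError(
            f"{fft_type} expects an input of size {expected_input}, got {input_length}"
        )
    return output
```

For example, `fft_output_length("RFFT", [8], 8)` returns `5`, while `fft_output_length("FFT", [8, 8], 8)` raises instead of producing a wrong answer.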
Thanks for the change. This is a relatively big change and seems to do a lot of things at once. I don't want to add extra burden for you to contribute this to the project, so I'm not going to hold you to it, but it is usually faster to merge small PRs with targeted changes. It's also easier to find missing test coverage that way. If you can "relatively easily" break this PR into smaller PRs, that would be better. I'll still put this on my review stack, though.
Thanks @MaheshRavishankar. The actual FFT implementation is one large change (RewriteFft.cpp) and a few minor ones. I don't think that can really be separated out easily. Fixing the tests that break as a result probably needs to be done at the same time. Adding new tests could be a different PR though, if that helps. I just assumed that tests were required when adding new functionality, so I made a point of finishing those before creating the PR 😄.
OK, sounds good. Let me review this and see how it goes.
Thanks @MaheshRavishankar. I'm very new to making changes in this project, so I wouldn't be surprised if there are things that could be done better.
The changes in RewriteFft.cpp still follow the same original algorithm, but I needed to refactor it to add the additional operations. Other than a bug fix, it still generates the same code as before for RFFT. If it helps, for when you get a chance to review, here's a high-level overview of the algorithm.
- FFT: (new) complex input and output. This is the canonical implementation and the rest are variations.
  - Shuffle and split input: FFT expects the inputs to be shuffled into bit-reversed order and have separate real and imaginary tensors. I do both of these at the same time in a single `linalg.generic`.
  - Run FFT: For each power of 2, call `linalg_ext.fft`, passing in generated coefficients as `arith.constant`s.
  - Complex output: Combine the real and imaginary parts back together into a complex tensor.
- IFFT: (new) complex input and output. Same as FFT with different coefficients and scaled by 1/N.
  - Shuffle and split input: Same as FFT but additionally scales by 1/N. FFTs are linear, so scaling the input or output has the same result. I figured that it was easiest to scale here while we're processing the input anyway instead of adding an additional step.
  - Run FFT: Same as FFT but with different coefficients.
  - Complex output: Same as FFT.
- RFFT: (existed before but slightly modified) real input of size N to complex output of size N/2+1. Same as FFT but with real input (imaginary part is zero) and truncated.
  - Shuffle and generate: Shuffle input like with FFT, but no need to split since the input is not complex. We need the imaginary tensor of all zeroes though. Originally this was done with an `arith.constant`, but that had an issue with large vector sizes since it was read-only. I tried an `empty` + `fill`, but that didn't help. Instead, I'm generating the tensor of zeroes as part of the shuffle loop. It's a bit of a workaround and maybe there's a better way, but it works.
  - Run FFT: Same as FFT.
  - Truncate: `tensor.extract_slice` to size N/2+1.
  - Complex output: Same as FFT.
- IRFFT: (new) complex input of size N/2+1 to real output of size N. Same as IFFT but need to expand the input back to size N and ignore the imaginary output.
  - Expand input: Make a tensor of size N/2-1 that's mirrored and conjugated using a `linalg.generic`, then `tensor.concat` it to the input to make a tensor of size N.
  - Call IFFT.
  - Keep only the real tensor and ignore the imaginary one.
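To make the overview above concrete, here is a minimal pure-Python model of the four operations on 1-D lists. It follows the listed steps (bit-reversal shuffle into separate real/imaginary lists, one butterfly pass per power-of-2 stage, recombination, then RFFT truncation and IRFFT expansion), but it is only an illustrative sketch: all function names are hypothetical, lengths are assumed to be powers of two, and the twiddle factors are computed on the fly where the actual rewrite emits them as `arith.constant`s and calls `linalg_ext.fft`.

```python
import math

# Hypothetical model of the rewrite on Python lists; lengths must be powers of two.


def bit_reverse_permutation(n):
    """Return the bit-reversed index order for an n-point radix-2 FFT."""
    bits = n.bit_length() - 1
    perm = []
    for i in range(n):
        r = 0
        for b in range(bits):
            if (i >> b) & 1:
                r |= 1 << (bits - 1 - b)
        perm.append(r)
    return perm


def shuffle_and_split(x, scale=1.0):
    """Bit-reverse a complex input and split it into real/imag lists in one loop."""
    re, im = [], []
    for p in bit_reverse_permutation(len(x)):
        v = complex(x[p])
        re.append(v.real * scale)
        im.append(v.imag * scale)
    return re, im


def shuffle_and_generate(x):
    """RFFT variant: bit-reverse a real input and emit the zero imaginary part in the same loop."""
    re, im = [], []
    for p in bit_reverse_permutation(len(x)):
        re.append(float(x[p]))
        im.append(0.0)
    return re, im


def run_fft_stages(re, im, inverse=False):
    """In-place radix-2 butterflies, one stage per power of 2 (m = 2, 4, ..., N)."""
    n, sign, m = len(re), (1.0 if inverse else -1.0), 2
    while m <= n:
        half = m // 2
        # Twiddle coefficients exp(sign * 2*pi*i*k/m) for this stage.
        w_re = [math.cos(sign * 2 * math.pi * k / m) for k in range(half)]
        w_im = [math.sin(sign * 2 * math.pi * k / m) for k in range(half)]
        for start in range(0, n, m):
            for k in range(half):
                a, b = start + k, start + k + half
                t_re = w_re[k] * re[b] - w_im[k] * im[b]
                t_im = w_re[k] * im[b] + w_im[k] * re[b]
                re[b], im[b] = re[a] - t_re, im[a] - t_im
                re[a], im[a] = re[a] + t_re, im[a] + t_im
        m *= 2
    return re, im


def combine(re, im):
    """Recombine the real and imaginary lists into complex values."""
    return [complex(r, i) for r, i in zip(re, im)]


def fft(x):
    return combine(*run_fft_stages(*shuffle_and_split(x)))


def ifft(x):
    # FFTs are linear, so the 1/N scale is folded into the input shuffle.
    re, im = shuffle_and_split(x, scale=1.0 / len(x))
    return combine(*run_fft_stages(re, im, inverse=True))


def rfft(x):
    full = combine(*run_fft_stages(*shuffle_and_generate(x)))
    return full[: len(x) // 2 + 1]  # truncate to N/2+1


def irfft(y, n):
    # Rebuild the Hermitian-symmetric spectrum: X[N-k] = conj(X[k]) for k = N/2-1 .. 1.
    mirrored = [complex(y[k]).conjugate() for k in range(n // 2 - 1, 0, -1)]
    full = list(y) + mirrored  # (N/2+1) + (N/2-1) = N elements
    return [v.real for v in ifft(full)]  # keep only the real part


x = [1.0, 2.0, 3.0, 4.0, 0.0, -1.0, 0.5, 2.5]
print(len(rfft(x)))  # 5
print(max(abs(a - b) for a, b in zip(irfft(rfft(x), 8), x)) < 1e-9)  # True
```

The round trip at the end exercises the IRFFT expansion: a real input has a conjugate-symmetric spectrum, so the N/2-1 dropped bins can be recovered from the N/2+1 that RFFT keeps.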
I hope this helps. Sorry for the big change. We needed IRFFT and that kind of necessitated implementing the whole family.
There was a CI failure due to a mistake I made with an excluded file in a `BUILD.bazel` file. Should be fixed now. Not sure if there is a way to test Bazel locally.
Fixed the pre-commit formatting changes. It's my first PR in this repo, so I had no idea that existed. I had been trying to stay within the 80-character line limit by hand. This is much easier.