
Implement missing stablehlo.fft operations

Open · pstarkcdpr opened this issue 3 weeks ago · 6 comments

See issue #22695. This change does the following:

  • Causes a compile error when the fft length has more than 1 dimension (this was never supported but previously gave the wrong answer silently).
  • Fixes an issue with RFFT where the imaginary part (a constant zero tensor) could cause a large stack allocation, making compilation fail. Fixed by creating the zero imaginary tensor as part of the input shuffling step.
  • Implements the remaining FFT operations: FFT, IFFT, IRFFT.
  • Changes the function name rewriteFft to rewriteRfft since it only handles RFFT. The function rewriteFft is now used for the full FFT. Also adds rewriteIrfft.
  • Updates Torch to use the new interface (i.e., changes calls to rewriteFft into calls to rewriteRfft).
  • Updates the Torch tests to reflect changes in RFFT code gen.
  • Updates the StableHLO tests to reflect changes in RFFT code gen.
  • Adds StableHLO tests for FFT, IFFT, and IRFFT.
  • Adds e2e tests for stablehlo.fft. Note that these are disabled in VMVX due to apparent issues with complex numbers; I mention this in the issue. They are also disabled in Vulkan (again mentioned in the issue). It's worth noting, though, that the FFTs do compile and run when using sm_75, but fail to compile when the Vulkan target is omitted. I haven't tried other Vulkan targets.

pstarkcdpr, Dec 04 '25 22:12

Thanks for the change. This is a relatively big change and seems to do a lot of things at once. I don't want to add extra burden for you to contribute this to the project, so I'm not going to hold you to it, but it is usually faster to merge small PRs with targeted changes. It is also easier to find missing test coverage that way. If you can "relatively easily" break this PR into smaller PRs, that would be better. I'll still put this on my review stack, though.

MaheshRavishankar, Dec 05 '25 02:12

Thanks @MaheshRavishankar. The actual FFT implementation is one large change (RewriteFft.cpp) plus a few minor ones. I don't think that can really be separated out easily. Fixing the tests that break as a result probably needs to be done at the same time. Adding new tests could be a different PR, though, if that helps. I just assumed that tests were required when adding new functionality, so I made a point of finishing those before creating the PR 😄.

pstarkcdpr, Dec 05 '25 04:12

OK, sounds good. Let me review this and see how it goes.

MaheshRavishankar, Dec 05 '25 05:12

Thanks @MaheshRavishankar. I'm very new to making changes in this project, so I wouldn't be surprised if there are things that could be done better.

The changes in RewriteFft.cpp still follow the same original algorithm, but I needed to refactor it to add the additional operations. Apart from a bug fix, it still generates the same code as before for RFFT. In case it helps when you get a chance to review, here's a high-level overview of the algorithm.

  • FFT: (new) complex input and output. This is the canonical implementation and the rest are variations.

    • Shuffle and split input: FFT expects the inputs to be shuffled into bit-reversed order and have separate real and imaginary tensors. I do both of these at the same time in a single linalg.generic.
    • Run FFT: For each power of 2 up to N (one stage each), call linalg_ext.fft, passing in generated coefficients as arith.constants.
    • Complex output: Combine the real and imaginary parts back together into a complex tensor.
  • IFFT: (new) complex input and output. Same as FFT with different coefficients and scaled by 1/N.

    • Shuffle and split input: Same as FFT, but additionally scales by 1/N. FFTs are linear, so scaling the input or the output gives the same result. I figured it was easiest to scale here, since we're processing the input anyway, instead of adding an additional step.
    • Run FFT: Same as FFT but with different coefficients.
    • Complex output: Same as FFT.
  • RFFT: (existed before but slightly modified) real input of size N to complex output of size N/2+1. Same as FFT but with real input (imaginary part is zero) and truncated.

    • Shuffle and generate: Shuffle the input as with FFT, but there's no need to split since the input is not complex. We still need an imaginary tensor of all zeros, though. Originally this was done with an arith.constant, but that had an issue with large vector sizes since it was read-only. I tried an empty + fill, but that didn't help. Instead, I'm generating the tensor of zeros as part of the shuffle loop. It's a bit of a workaround and maybe there's a better way, but it works.
    • Run FFT: Same as FFT.
    • Truncate: tensor.extract_slice to size N/2+1
    • Complex output: Same as FFT
  • IRFFT: (new) complex input of size N/2+1 to real output of size N. Same as IFFT but need to expand the input back to size N and ignore the imaginary output.

    • Expand input: Make a tensor of size N/2-1 that's mirrored and conjugated using a linalg.generic, then tensor.concat it onto the input (size N/2+1) to make a tensor of size N.
    • Call IFFT.
    • Keep only the real tensor and ignore the imaginary one.
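As a rough cross-check of the overview above, here is a minimal pure-Python reference model of the four variants. It is a sketch under stated assumptions, not the RewriteFft.cpp implementation: it assumes a power-of-2 length and an iterative radix-2 formulation, the helper names (bit_reverse, shuffle_and_split, run_stages) are hypothetical, per-stage Python lists stand in for the arith.constant coefficient tensors, and plain loops stand in for linalg.generic and linalg_ext.fft.

```python
import math


def bit_reverse(i, bits):
    """Reverse the lowest `bits` bits of index i."""
    r = 0
    for _ in range(bits):
        r = (r << 1) | (i & 1)
        i >>= 1
    return r


def shuffle_and_split(x, scale=1.0):
    """Bit-reverse the input and split it into real/imag lists in one pass.

    A real-valued input gets its all-zero imaginary list from this same loop,
    mirroring the RFFT "shuffle and generate" step described above.
    """
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of 2"
    bits = n.bit_length() - 1
    re, im = [0.0] * n, [0.0] * n
    for i in range(n):
        v = complex(x[bit_reverse(i, bits)]) * scale
        re[i], im[i] = v.real, v.imag
    return re, im


def run_stages(re, im, inverse=False):
    """Radix-2 butterflies, one stage per power of 2 (m = 2, 4, ..., n)."""
    n = len(re)
    sign = 1.0 if inverse else -1.0
    m = 2
    while m <= n:
        half = m // 2
        # Per-stage coefficients w_k = exp(sign * 2*pi*i*k / m).
        c_re = [math.cos(sign * 2 * math.pi * k / m) for k in range(half)]
        c_im = [math.sin(sign * 2 * math.pi * k / m) for k in range(half)]
        for start in range(0, n, m):
            for k in range(half):
                a, b = start + k, start + k + half
                t_re = c_re[k] * re[b] - c_im[k] * im[b]  # t = w_k * x[b]
                t_im = c_re[k] * im[b] + c_im[k] * re[b]
                re[b], im[b] = re[a] - t_re, im[a] - t_im
                re[a], im[a] = re[a] + t_re, im[a] + t_im
        m *= 2
    return re, im


def fft(x):
    re, im = run_stages(*shuffle_and_split(x))
    return [complex(r, i) for r, i in zip(re, im)]


def ifft(x):
    # Scale by 1/N while shuffling; by linearity this equals scaling the output.
    re, im = run_stages(*shuffle_and_split(x, 1.0 / len(x)), inverse=True)
    return [complex(r, i) for r, i in zip(re, im)]


def rfft(x):
    # Real input (imaginary part zero), truncated to the first N/2+1 bins.
    return fft(x)[: len(x) // 2 + 1]


def irfft(y, n):
    # Append N/2-1 mirrored conjugates (bins N/2-1 down to 1) to get N bins,
    # run the IFFT, and keep only the real part.
    assert len(y) == n // 2 + 1
    mirrored = [y[k].conjugate() for k in range(n // 2 - 1, 0, -1)]
    return [v.real for v in ifft(list(y) + mirrored)]
```

For example, `rfft([1.0, 2.0, 3.0, 4.0])` should return approximately `[10, -2+2j, -2]`, and `irfft` of that with `n=4` should recover the input up to rounding. Folding the 1/N scale into the shuffle and generating zeros inside the same loop keeps each variant to the same shuffle, stages, and output steps.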

I hope this helps. Sorry for the big change. We needed IRFFT and that kind of necessitated implementing the whole family.

pstarkcdpr, Dec 05 '25 18:12

There was a CI failure due to a mistake I made with an excluded file in a BUILD.bazel file. It should be fixed now. I'm not sure if there is a way to test Bazel locally.

pstarkcdpr, Dec 09 '25 07:12

Fixed the pre-commit formatting issues. It's my first PR in this repo, so I had no idea that existed. I had been trying to stay within the 80-character line limit by hand. This is much easier.

pstarkcdpr, Dec 09 '25 21:12