Improvements to stablehlo.fft

Open · pstarkcdpr opened this issue 1 month ago · 7 comments

Request description

Currently, lowering stablehlo.fft has a number of issues, including:

  1. Only RFFT is supported by the compiler, even though FFT, IFFT, and IRFFT are accepted by StableHLO.
  2. The compiler does not deliberately reject FFT, IFFT, and IRFFT. It still produces compiler errors for them, but more by accident than by design.
  3. Large RFFT operations fail due to an error with readonly tensors (see issue 22473).
  4. Only 1D FFTs are supported.
  5. If a multidimensional FFT is requested (e.g. stablehlo.fft %0, type = RFFT, length = [128, 128]), a 1D FFT is silently performed on the last dimension only, which does not produce the expected output.

We need to use FFTs for some of our models, so we've made some improvements to the current state. I've made a branch with these improvements here: https://github.com/pstarkcdpr/iree/tree/fft_1d

We addressed some, but not all, of the above issues:

  1. Implemented the remaining FFT operations (FFT, IFFT, and IRFFT).
  2. No need to fail now!
  3. Changed the way that the all-zero imaginary input tensor is created to avoid it being marked as readonly, which fixes this issue.
  4. Multidimensional FFTs are still not implemented.
  5. Cause a compile error if the FFT length attribute has more than one entry (i.e. a multidimensional FFT is requested). Note that you can still do a 1D FFT on the last dimension of a multidimensional tensor; you just can't do a multidimensional FFT. It now fails at compile time instead of silently giving the wrong answer.
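Fix 3 relies on the fact that an RFFT is a complex FFT whose imaginary input is all zeros, truncated to the first n/2 + 1 outputs; the remaining bins are complex conjugates of those (Hermitian symmetry), so they carry no extra information. Below is a minimal stdlib-only Python sketch of that relationship. The naive O(n²) DFT and the helper names are illustrative stand-ins, not IREE code.

```python
import cmath


def dft(x):
    """Naive O(n^2) forward DFT: X[k] = sum_t x[t] * exp(-2*pi*i*t*k/n)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * t * k / n) for t in range(n))
            for k in range(n)]


def rfft_via_complex(x_real):
    """RFFT computed as a complex FFT of (real, all-zero imaginary) input."""
    n = len(x_real)
    x = [complex(re, 0.0) for re in x_real]  # the all-zero imaginary part
    return dft(x)[: n // 2 + 1]              # keep only the non-redundant bins


def is_hermitian(spectrum, tol=1e-9):
    """True if X[k] == conj(X[(n - k) % n]) for every k, as for any real input."""
    n = len(spectrum)
    return all(abs(spectrum[k] - spectrum[(n - k) % n].conjugate()) <= tol
               for k in range(n))


x = [1.0, 2.0, 0.5, -1.0, 3.0, 0.0, 2.5, -0.5]
full = dft([complex(v, 0.0) for v in x])
half = rfft_via_complex(x)
print(len(half))           # 5, i.e. n // 2 + 1 for n = 8
print(is_hermitian(full))  # True, so the dropped bins are redundant
```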

I've never contributed to this project before, so I'm not entirely certain of the process. I can make a PR from my fork. Internally, I also have a Python script that I'm using to test the FFTs. It compares the results to NumPy (within epsilon) and also tests round trips within IREE (input == ifft(fft(input))). I'm not sure how IREE's testing works, so I don't know how to integrate something like this into the project.
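The two checks described here (agreement with a reference within a tolerance, and the round trip input == ifft(fft(input))) can be sketched in stdlib-only Python. In this sketch a recursive radix-2 FFT stands in for the IREE-compiled function and a naive DFT stands in for NumPy's np.fft.fft; both helpers and the tolerance are assumptions for illustration, not the author's actual script.

```python
import cmath
import random


def fft(x):
    """Recursive radix-2 Cooley-Tukey FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    even, odd = fft(x[0::2]), fft(x[1::2])
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(-2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out


def ifft(X):
    """Inverse FFT via the identity ifft(X) = conj(fft(conj(X))) / n."""
    n = len(X)
    return [v.conjugate() / n for v in fft([v.conjugate() for v in X])]


def reference_dft(x):
    """Naive O(n^2) DFT used as ground truth (a real harness might use np.fft.fft)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * t * k / n) for t in range(n))
            for k in range(n)]


def all_close(a, b, atol=1e-6):
    """Elementwise |a - b| <= atol, analogous to an 'almost equal' check."""
    return len(a) == len(b) and all(abs(p - q) <= atol for p, q in zip(a, b))


random.seed(0)
x = [complex(random.random(), random.random()) for _ in range(32)]
X = fft(x)
print(all_close(X, reference_dft(x)))  # agrees with the reference
print(all_close(ifft(X), x))           # round trip recovers the input
```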

Also, there is still one issue. If you do not use all the output from the FFTs, you can get a 'func.func' op exceeded stack allocation limit error when the input size is large enough. This happens most frequently with IRFFT, since it throws away the imaginary output immediately, but it can also happen with the other operations, such as IFFT when you only need the real part. For small input sizes this isn't an issue. For larger sizes, we have been working around it by using the outputs in non-trivial ways so they don't get optimized into readonly tensors. However, I don't believe this is an issue with the lowering from StableHLO; it's more likely an issue with linalg_ext.fft. See issue https://github.com/iree-org/iree/issues/22776.
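For context on why IRFFT discards its imaginary output: IRFFT receives only the n/2 + 1 non-redundant bins, rebuilds the full conjugate-symmetric spectrum, and runs an inverse complex FFT, whose imaginary part is then zero up to rounding. A stdlib-only Python sketch of that idea follows; the naive DFT and the function names are illustrative assumptions, not the IREE lowering itself.

```python
import cmath


def dft(x, inverse=False):
    """Naive O(n^2) DFT; inverse=True uses exp(+2*pi*i*t*k/n) and scales by 1/n."""
    n = len(x)
    sign = 1 if inverse else -1
    out = [sum(x[t] * cmath.exp(sign * 2j * cmath.pi * t * k / n) for t in range(n))
           for k in range(n)]
    return [v / n for v in out] if inverse else out


def irfft_via_complex(half, n):
    """IRFFT of n // 2 + 1 bins to n real samples via a full inverse complex DFT.

    Returns (real_output, largest_discarded_imaginary_magnitude).
    """
    if len(half) != n // 2 + 1:
        raise ValueError("expected n // 2 + 1 input bins")
    # Rebuild the missing bins from Hermitian symmetry: X[k] = conj(X[n - k]).
    full = list(half) + [half[n - k].conjugate() for k in range(n // 2 + 1, n)]
    result = dft(full, inverse=True)
    # For a Hermitian spectrum the imaginary part is ~0, so only .real is kept.
    return [v.real for v in result], max(abs(v.imag) for v in result)


x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
half = dft([complex(v, 0.0) for v in x])[: len(x) // 2 + 1]
y, max_imag = irfft_via_complex(half, len(x))
print(max(abs(a - b) for a, b in zip(x, y)) < 1e-9)  # True: real input recovered
print(max_imag < 1e-9)                               # True: imaginary part is rounding noise
```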

What component(s) does this issue relate to?

Compiler

Additional context

No response

pstarkcdpr, Nov 18 '25 23:11

Thanks @pstarkcdpr. It's great that you were able to use FFT. It was something we were looking at a long time ago, but we haven't paid much attention to it recently. If you have patches that you want to send out for review to land in tree, we can definitely look into those. The things you fixed seem reasonable as a starting point.

With respect to adding tests: for in-tree tests we typically add MLIR-only files like https://github.com/iree-org/iree/blob/main/tests/e2e/stablehlo_ops/fft.mlir . That is the simplest FFT test we have, and it only checks functionality. If you want larger tests, we will need to think of another mechanism to add them. For example:

  1. We just added a new test suite for testing torch ops: https://github.com/iree-org/iree-test-suites/tree/main/torch_ops . We will add more to it, but these are basically pytest tests that take a JSON file describing the test to run. They then get pulled into IREE's CI.
  2. We have many more tests from ONNX that were set up a while ago: https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops . They also use pytest, but they are set up slightly differently.

We don't have anything that tests StableHLO end to end (apart from the basic in-tree tests I pointed you to above).

> Also, there is still one issue. If you do not use all the output from the FFTs, then you can get a 'func.func' op exceeded stack allocation limit error when the input size is large enough. This happens most frequently with IRFFT since it throws away the imaginary output immediately, but can also happen with the other operations, like IFFT if you only need the real part. For small input sizes, this isn't an issue. For larger sizes, we have been working around it by using the outputs in non-trivial ways so they don't get optimized into readonly tensors. I don't believe this is an issue with the lowering from StableHLO however, but more of an issue with linalgext.fft. See issue https://github.com/iree-org/iree/issues/22473.

Yeah, this has been an issue. We typically solve it with vectorization: if you vectorize correctly, you never get stack allocations, and we vectorize most of our performance-critical cases. For FFT, because we only did very basic work on it, we don't vectorize. If that were added, it would probably make the problem go away.

One way to get around the problem is to use --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false. This is really a workaround for the missing vectorization. When I get some time I can look deeper, because if the result is really unused it SHOULD get DCE-ed (dead-code eliminated), but that might not be happening here.

MaheshRavishankar, Nov 21 '25 21:11

Thanks @MaheshRavishankar. I should be able to adapt my tests to use the same mechanism as the existing FFT tests (check.expect_almost_eq_const). I'll do that and then make a PR with my changes.
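Moving script-based checks to check.expect_almost_eq_const means writing the expected values out as MLIR dense literals of the form `dense<[...]> : tensor<Nxf32>`, with complex inputs as `(re,im)` pairs. A small hypothetical Python helper for producing those strings from reference values (for example NumPy output) might look like this. The helper names are made up, and the textual format is copied from the existing e2e tests in this thread rather than from a specification.

```python
def fmt(v):
    """Format a finite float with enough digits to round-trip f32, keeping a '.' or exponent."""
    s = format(v, ".9g")
    return s if ("." in s or "e" in s) else s + ".0"


def dense_f32(values):
    """MLIR dense literal for a 1D f32 tensor, e.g. dense<[1.0, 2.5]> : tensor<2xf32>."""
    body = ", ".join(fmt(v) for v in values)
    return f"dense<[{body}]> : tensor<{len(values)}xf32>"


def dense_complex_f32(values):
    """MLIR dense literal for a 1D complex<f32> tensor, written as (re,im) pairs."""
    body = ", ".join(f"({fmt(v.real)},{fmt(v.imag)})" for v in values)
    return f"dense<[{body}]> : tensor<{len(values)}xcomplex<f32>>"


def expect_lines(spectrum, real_name="%1", imag_name="%2"):
    """check.expect_almost_eq_const lines for the real and imaginary parts of a result."""
    t = f"tensor<{len(spectrum)}xf32>"
    real_lit = dense_f32([v.real for v in spectrum])
    imag_lit = dense_f32([v.imag for v in spectrum])
    return [
        f"check.expect_almost_eq_const({real_name}, {real_lit}) : {t}",
        f"check.expect_almost_eq_const({imag_name}, {imag_lit}) : {t}",
    ]


# The FFT of a unit impulse is all ones, which makes a handy sanity case.
impulse = [1 + 0j, 0j, 0j, 0j]
print(dense_complex_f32(impulse))
for line in expect_lines([1 + 0j] * 4):
    print(line)
```

Since complex constants turned out to be a problem on some backends, the real and imaginary inputs can also be emitted separately with dense_f32.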

Is it worth checking expected failure cases? Multidimensional FFTs aren't supported, so I was testing that the compile failed in this case. Not sure if the in-tree tests support that though.

What would be involved in vectorization? Is there an example where it was done previously to another piece of code? linalg_ext.fft looks to be relatively straightforward (especially since we always pass in the coefficients), and we're using it in a performance-sensitive part of the model, so I'm curious how hard it would be to vectorize.

pstarkcdpr, Nov 21 '25 22:11

> Is it worth checking expected failure cases? Multidimensional FFTs aren't supported, so I was testing that the compile failed in this case. Not sure if the in-tree tests support that though.

Probably not. Just out of curiosity, though: couldn't you support multidimensional FFTs as a sequence of 1D FFTs along the different dimensions?

> What would be involved in vectorization? Is there an example where it was done previously to another piece of code? linalg_ext.fft looks to be relatively straightforward (especially since we always pass in the coefficients), and we're using it in a performance-sensitive part of the model, so I'm curious how hard it would be to vectorize.

If you know how to vectorize it, it should be relatively easy. You can take a look at examples here: https://github.com/iree-org/iree/blob/df3d0762b6120e3740f82bbfe28e0130cf668c9b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp#L197 .

There are also specific passes you can write, such as https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/VectorizeIREELinalgExtOps.cpp , that you can use.
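For a sense of what would be vectorized, here is a generic iterative radix-2 FFT in stdlib-only Python: a bit-reversal permutation followed by log2(n) stages, each using a precomputed table of twiddle coefficients. This is the textbook Cooley-Tukey structure, not IREE's linalg_ext.fft implementation. The point is that all butterflies within one stage touch disjoint element pairs, so a stage is data-parallel.

```python
import cmath


def reverse_bits(i, bits):
    """Reverse the lowest `bits` bits of i, e.g. reverse_bits(1, 3) == 4 (001 -> 100)."""
    r = 0
    for _ in range(bits):
        r = (r << 1) | (i & 1)
        i >>= 1
    return r


def bit_reverse_permute(x):
    """Return x reordered so that position i holds x[reverse_bits(i)]."""
    bits = len(x).bit_length() - 1
    return [x[reverse_bits(i, bits)] for i in range(len(x))]


def stage_coefficients(m):
    """Twiddle factors exp(-2*pi*i*j/m), j < m/2, for a stage with butterfly span m."""
    return [cmath.exp(-2j * cmath.pi * j / m) for j in range(m // 2)]


def fft_iterative(x):
    """Iterative radix-2 Cooley-Tukey FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    a = [complex(v) for v in bit_reverse_permute(list(x))]
    m = 2
    while m <= n:
        w = stage_coefficients(m)  # one coefficient table per stage, built before the butterflies
        half = m // 2
        # Each (start, j) butterfly reads and writes its own pair of elements, so all
        # butterflies in a stage are independent; a vectorized lowering could map
        # them onto SIMD lanes.
        for start in range(0, n, m):
            for j in range(half):
                u = a[start + j]
                v = w[j] * a[start + j + half]
                a[start + j] = u + v
                a[start + j + half] = u - v
        m *= 2
    return a


# A unit impulse transforms to all ones.
print([round(abs(v), 12) for v in fft_iterative([1.0, 0.0, 0.0, 0.0])])  # [1.0, 1.0, 1.0, 1.0]
```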

MaheshRavishankar, Nov 21 '25 22:11

Thanks. I doubt that I know how to vectorize it since I've never looked at any of this before, but I'm still curious to look. Thanks for the pointers.

Yes, multidimensional FFTs are easy in principle. I think you just do a 1D FFT on the innermost dimension, then a 1D FFT of the result on the next dimension, and so on. The outer loop should be easy enough to generate; the trouble is that linalg_ext.fft doesn't actually take the dimension as a parameter. It always operates on the innermost one.

From the docs:

iree_linalg_ext.fft (LinalgExt::FftOp)
Apply 1D FFT to innermost dim.
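A stdlib-only Python sketch of that decomposition for the 2D case, assuming a 1D FFT primitive that, like linalg_ext.fft, only transforms the innermost dimension: transform the rows, transpose, transform again, and transpose back. A naive DFT stands in for the real kernel, and all names are hypothetical. The last check also shows why running a 1D FFT on the last dimension alone is not a 2D FFT.

```python
import cmath


def fft1d_innermost(rows):
    """Stand-in for a 1D FFT that only transforms the innermost dim (naive DFT per row)."""
    out = []
    for row in rows:
        n = len(row)
        out.append([sum(row[t] * cmath.exp(-2j * cmath.pi * t * k / n) for t in range(n))
                    for k in range(n)])
    return out


def transpose(m):
    """Swap the two dims of a 2D list-of-lists."""
    return [list(col) for col in zip(*m)]


def fft2d(m):
    """2D FFT built only from innermost-dim 1D FFTs plus transposes."""
    rows_done = fft1d_innermost(m)                     # FFT along the innermost dim
    cols_done = fft1d_innermost(transpose(rows_done))  # outer dim, moved innermost
    return transpose(cols_done)                        # restore the original layout


def dft2d_direct(m):
    """Direct 2D DFT: X[u][v] = sum_{r,c} m[r][c] * exp(-2*pi*i*(r*u/R + c*v/C))."""
    R, C = len(m), len(m[0])
    return [[sum(m[r][c] * cmath.exp(-2j * cmath.pi * (r * u / R + c * v / C))
                 for r in range(R) for c in range(C))
             for v in range(C)]
            for u in range(R)]


def close2d(a, b, tol=1e-9):
    """Elementwise comparison of two equally shaped 2D lists."""
    return all(abs(p - q) <= tol for ra, rb in zip(a, b) for p, q in zip(ra, rb))


x = [[1.0, 2.0, 0.0, -1.0],
     [0.5, 0.0, 3.0, 1.0],
     [2.0, -2.0, 1.0, 0.0]]
print(close2d(fft2d(x), dft2d_direct(x)))            # True: separable == direct 2D DFT
print(close2d(fft1d_innermost(x), dft2d_direct(x)))  # False: last dim only is not a 2D FFT
```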

Luckily, we don't need multidimensional FFTs in our immediate use-case. We considered it, which is why I tested them, but didn't end up needing them. So I just added the compile error and moved on.

pstarkcdpr, Nov 21 '25 22:11

I started to put together the e2e tests for FFTs. My regular tests work on LLVM and Vulkan, but converting them to an MLIR test leads to a compile error. For example:

func.func @fft_1d() {
  %input = util.unfoldable_constant dense<[(0.21740782,0.059113156), (0.778621,0.6152795), (0.6581351,0.72274965), (0.85420084,0.5449456), (0.8268271,0.059995476), (0.1314825,0.99552083), (0.07452258,0.97213143), (0.72783166,0.64232296), (0.21783765,0.077186435), (0.2372208,0.9866753), (0.76789004,0.7888627), (0.7672063,0.11202169), (0.20171618,0.71090525), (0.5671336,0.5172962), (0.5653311,0.13171488), (0.6019718,0.22600187), (0.7669061,0.99419916), (0.7296996,0.19616151), (0.49405196,0.23033695), (0.91676044,0.14108129), (0.25893882,0.5168747), (0.0022013448,0.45859325), (0.6937525,0.12294605), (0.14645115,0.8247891), (0.6886818,0.58247185), (0.15621811,0.4682121), (0.7700238,0.4526145), (0.8244305,0.7325929), (0.004420892,0.5293484), (0.0894202,0.36284673), (0.297367,0.76520026), (0.7746221,0.8835885)]> : tensor<32xcomplex<f32>>
  %0 = stablehlo.fft %input, type = FFT, length = [32] : (tensor<32xcomplex<f32>>) -> tensor<32xcomplex<f32>>
  %1 = stablehlo.real %0 : (tensor<32xcomplex<f32>>) -> tensor<32xf32>
  %2 = stablehlo.imag %0 : (tensor<32xcomplex<f32>>) -> tensor<32xf32>
  check.expect_almost_eq_const(%1, dense<[15.809282, 0.21465153, 1.181782, -0.60642904, 0.013650537, -0.9517257, 0.43036938, 0.98342496, -0.6450963, -2.5143266, -0.7090554, -1.3734092, -0.39120948, 0.28289393, -0.8894674, -0.40649498, -0.801661, -0.7713141, -2.8925278, -0.5026806, 1.5893525, -3.5212724, 1.3363503, 2.5882864, -1.6315789, -0.8786793, 0.077445, 0.4787432, 1.1839279, -0.29849604, 2.0874598, -1.5151438]> : tensor<32xf32>) : tensor<32xf32>
  check.expect_almost_eq_const(%2, dense<[16.424582, 1.0421104, -1.6126406, -1.8765254, -2.0437074, -0.34451416, 1.5531765, 0.13303947, 2.2650156, -4.2676077, -1.4725999, -1.0876011, -1.3087398, 1.2196933, 1.8778068, -3.174349, -0.9912782, -0.66518277, 0.6316352, -2.4745326, -0.28285468, 0.21321717, 2.3605132, -4.4211693, -3.57794, -1.9363232, 0.5097058, -0.54355776, 3.218689, 1.0246719, -0.6983651, 2.1972544]> : tensor<32xf32>) : tensor<32xf32>
  return
}

fails with

iree-org/iree/tests/e2e/stablehlo_ops/fft.mlir:33:12: error: 'util.buffer.store' op operand #0 must be index or integer or floating-point, but got 'complex<f32>'
  %input = stablehlo.complex %input_real, %input_imag : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xcomplex<f32>>

My regular tests also fail to compile with the vmvx backend (but LLVM and Vulkan are fine). Any suggestions on how to resolve this?

pstarkcdpr, Dec 02 '25 18:12

Actually, is the above error a problem with the vmvx backend? I made a simple test that had nothing to do with FFTs but used complex numbers, and I got the same error.

func.func @complex_extract() {
  %input = util.unfoldable_constant dense<[(0.1,0.2), (0.3,0.4), (0.5,0.6), (0.7,0.8)]> : tensor<4xcomplex<f32>>
  %1 = stablehlo.real %input : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
  %2 = stablehlo.imag %input : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
  check.expect_almost_eq_const(%1, dense<[0.1,0.3,0.5,0.7]> : tensor<4xf32>) : tensor<4xf32>
  check.expect_almost_eq_const(%2, dense<[0.2,0.4,0.6,0.8]> : tensor<4xf32>) : tensor<4xf32>
  return
}

Should something like the above be part of the regular StableHLO e2e tests anyway? I can see that there's a test for complex numbers in complex.mlir, but if I use --compile-to=input on that, everything gets converted to f32 right away and no complex operations reach the lower dialects. That's not the case with the above program. Should I log an issue about vmvx not supporting complex numbers? Am I blocked from making a PR with my FFT changes because the tests won't work on vmvx?

pstarkcdpr, Dec 02 '25 19:12

Okay, I added the new e2e tests in a new MLIR file and excluded that file on vmvx.

I also excluded it on Vulkan because there is sometimes a compile error, depending on the target. If I compile with --iree-hal-target-device=vulkan --iree-vulkan-target=sm_75, everything compiles and runs fine. If I omit the Vulkan target and just use --iree-hal-target-device=vulkan (which is basically what the e2e tests do), then I get:

error: failed to legalize unresolved materialization from ('i64') to ('vector<2xi32>') that remained live after conversion

Unless there's a reason to do otherwise, I'll make a PR tomorrow, and then we can figure out if other issues should be logged as a result of this.

pstarkcdpr, Dec 03 '25 01:12