
FFT Module (like torch.fft)

discordance opened this issue 1 year ago • 11 comments

Hello burn team! Thank you for all the hard work!

Feature description

It would be incredibly valuable to have a Fast Fourier Transform (FFT) module, similar to torch.fft in PyTorch.

Feature motivation

Having it as part of the framework would add a foundational building block for differentiable digital signal processing (DDSP), spectral losses, and audio-related tasks.

discordance avatar Sep 08 '23 13:09 discordance

If this feature is added, this crate might be helpful: https://github.com/ejmahler/RustFFT

It provides very fast, SIMD-accelerated FFT processing.
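
For reference, here is a minimal sketch of what using RustFFT looks like on the CPU (based on the crate's `FftPlanner` API; this is just an illustration, not a proposed Burn integration):

```rust
use rustfft::{num_complex::Complex, FftPlanner};

fn main() {
    // Plan a forward FFT of length 8; RustFFT picks a SIMD-accelerated
    // implementation for the target CPU where one is available.
    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(8);

    // RustFFT transforms complex buffers in place.
    let mut buffer: Vec<Complex<f32>> = (0..8)
        .map(|i| Complex::new(i as f32, 0.0))
        .collect();
    fft.process(&mut buffer);

    println!("{:?}", buffer);
}
```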

TheAndrewJackson avatar Sep 08 '23 17:09 TheAndrewJackson

Since Burn uses different backends, we should take advantage of this. A CPU-only implementation would incur data-transfer overhead for GPU backends. I believe @Gadersd has already implemented something using Burn ops: https://github.com/Gadersd/whisper-burn

Perhaps we can implement a default for the CPU and let GPU backends override it with a more specialized implementation.
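
A rough sketch of that "default on the CPU, override per backend" idea; the trait and names here are purely illustrative, not Burn's actual backend API:

```rust
/// Hypothetical trait sketch (not Burn's real API): the trait ships a generic,
/// portable CPU fallback that a GPU backend can replace with a device kernel.
pub trait FftBackend {
    /// 1D FFT over interleaved (re, im) pairs; `data.len()` must be `2 * n`.
    fn fft1d(&self, data: &mut [f32]) {
        // Default: naive O(n^2) DFT on the CPU, kept only as a fallback.
        let n = data.len() / 2;
        let input = data.to_vec();
        for k in 0..n {
            let (mut re, mut im) = (0.0f32, 0.0f32);
            for j in 0..n {
                let angle = -2.0 * std::f32::consts::PI * (k as f32) * (j as f32) / n as f32;
                let (s, c) = angle.sin_cos();
                let (xr, xi) = (input[2 * j], input[2 * j + 1]);
                re += xr * c - xi * s;
                im += xr * s + xi * c;
            }
            data[2 * k] = re;
            data[2 * k + 1] = im;
        }
    }
}

struct CpuBackend;
impl FftBackend for CpuBackend {} // uses the default implementation

struct GpuBackend;
impl FftBackend for GpuBackend {
    fn fft1d(&self, _data: &mut [f32]) {
        // A specialized backend would dispatch a device kernel here instead.
        unimplemented!("launch GPU FFT kernel")
    }
}
```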

antimora avatar Sep 08 '23 18:09 antimora

Feel free to use my FFT implementation, but be aware that it isn’t as general as the one torch uses. There are some padding options and checks that should be added for generality.

Gadersd avatar Sep 08 '23 19:09 Gadersd

I've started work on a PR for this feature, based on an NdArray / Wgpu op I wrote as a custom backend a few months back. There are a few design decisions to get right. We don't support complex numbers at the moment, so as I see it there are three ways to handle complex inputs and outputs for each op (a rough signature sketch of option 3 follows the list):

  1. Add a complex type as an Element - probably the most idiomatic but I'm not sure if this is practical
  2. Require separate real and imaginary tensors
  3. Require a complex dimension of length 2 as the last dimension
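
For illustration, option 3 could look roughly like this at the tensor API level (signature only; the function name and shape convention are hypothetical, not an agreed design):

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical 1D FFT over the last dimension of a real tensor.
/// Input shape:  [batch, n]
/// Output shape: [batch, n, 2], where the trailing dimension of length 2 holds
/// the (real, imaginary) parts, since there is no complex element type yet.
pub fn fft1d<B: Backend>(input: Tensor<B, 2>) -> Tensor<B, 3> {
    // Body intentionally elided; only the shape contract matters for the discussion.
    let _ = input;
    unimplemented!()
}
```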

Secondly, I'm guessing nobody will be too upset if we don't have a backward pass for FFT? The derivative of the Fourier transform of a function is itself complex, and propagating imaginary gradients around the place seems like a bad idea. I'm assuming folks will want FFT to preprocess / extract features rather than as an op that needs to be in the graph.
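
For context, one standard way to see what a backward pass would involve (assuming the unnormalized DFT convention): the vector-Jacobian product of the DFT is a scaled inverse DFT of the incoming cotangent, which is itself complex-valued.

$$
y_k = \sum_{j=0}^{N-1} x_j \, e^{-2\pi i jk/N}
\quad\Longrightarrow\quad
\bar{x} = F^{\mathsf H}\bar{y} = N\,\mathcal{F}^{-1}(\bar{y}),
\qquad \text{since } F^{\mathsf H} F = N I.
$$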

Thoughts?

TomWyllie avatar Mar 15 '24 08:03 TomWyllie

Having it for the forward pass will cover the majority of use cases. As you mentioned, many simply need to extract features, and it is better to do that on the backend device instead of the CPU.

antimora avatar Mar 15 '24 14:03 antimora

The only thing I would suggest is to file issues for the remaining work so that others can pick it up or refer to the known deficiencies.

antimora avatar Mar 15 '24 14:03 antimora

I have a branch for an FFT op just about ready for a pull request :slightly_smiling_face: It's implemented and working for NdArray and WGPU with a WGSL shader. I've done the 1D FFT and 1D inverse FFT, and I will open an issue for 2D support soon. I haven't added Candle or autodiff. There are a couple of outstanding tasks before it's ready to be merged to main, though:

  • Rewrite the WGSL kernel into the new JIT GPU representation (@louisfd, are you still up for this?)
  • I'm not sure how to register this as an op in the fusion backend. The input and output shapes are the same for all the FFT ops, which I suppose makes the description of the op simpler; is there an existing op that I can base the implementation on?
  • I've not added support for the Torch backend, but we can do this down the line too; libtorch does support FFT.

If somebody can make an FFT branch in this repo, then I can merge my fork into it.

TomWyllie avatar Mar 29 '24 20:03 TomWyllie

Hi Tom, I do believe the Fourier transform would be a nice addition to Burn 😄 I made a branch so we can work on this together: https://github.com/tracel-ai/burn/tree/feat/ops/fft

I looked at your WGSL and it won't be very hard to translate to JIT. But before that, can you make sure it's thoroughly tested? The translation process is error-prone, so I need tests to make sure I don't screw up.

louisfd avatar Mar 30 '24 12:03 louisfd

Sounds good. I have a bunch of tests, but I'll check some test cases from other libraries tomorrow to make sure I've covered all the edge cases; there are a lot of different parts to the algorithm.

TomWyllie avatar Mar 31 '24 11:03 TomWyllie

Hello folks! This feature would indeed be pretty helpful. I was about to experiment with using Burn to train some audio models and found myself reaching for an FFT module. I'm happy to help with such a feature. I saw there is a feature branch, but I wasn't sure what its current state is. Are there any tasks that can be peeled off that I could contribute to?

One additional note: it would be wonderful if autodiff functionality were included. Reading the feature branch, it seems this might be pushed off to the future. Differentiating through the FFT is used in many ML audio applications (the spectral losses, DDSP, etc. mentioned in the original issue text). As a user, I would find an implementation without an autodiff backend very surprising. Again, I'm happy to help make this a reality.

jbelanich avatar Sep 17 '24 13:09 jbelanich

I think the right approach would be to provide FFT kernels in CubeCL; we can then bundle them with Burn.

nathanielsimard avatar Sep 24 '24 13:09 nathanielsimard

> I think the right approach would be to provide FFT kernels in CubeCL; we can then bundle them with Burn.

It would be great to have FFT in CubeCL as well!

As far as I can tell, the larger Rust ecosystem is currently lacking a portable FFT that can also run on the GPU.

For the CPU, there's RustFFT and ndrustfft, a wrapper for RustFFT for complex/real n-dimensional FFTs (using ndarray).

For the GPU, there are crates like vkfft-rs, tch-rs (a Rust PyTorch wrapper), and arrayfire-rust (a wrapper for ArrayFire).

But it would be great to also have a portable WebGPU FFT, so we don't have to install huge dependencies and can ship stand-alone binaries.

An FFT on the GPU, especially for large data, is normally much faster than on the CPU (depending on the GPU and CPU, of course).

For instance, 3D image data in medical imaging and microscopy can be gigabytes in size, and in that case one would normally want to run the FFT on the GPU.

Having an efficient, portable (n-dimensional) FFT (e.g. based on WebGPU) would also enable significantly faster convolutions on large data.
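
To illustrate that last point, here is a small CPU sketch of FFT-based (circular) convolution using RustFFT; the same frequency-domain idea is what a portable GPU FFT would accelerate, replacing O(N·K) direct convolution with O(N log N) work:

```rust
use rustfft::{num_complex::Complex, FftPlanner};

/// Circular convolution via the convolution theorem: FFT both inputs,
/// multiply pointwise in the frequency domain, then inverse-FFT the product.
fn circular_convolve(signal: &[f32], kernel: &[f32]) -> Vec<f32> {
    let n = signal.len();
    assert_eq!(n, kernel.len(), "pad the kernel to the signal length first");

    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(n);
    let ifft = planner.plan_fft_inverse(n);

    let mut a: Vec<Complex<f32>> = signal.iter().map(|&x| Complex::new(x, 0.0)).collect();
    let mut b: Vec<Complex<f32>> = kernel.iter().map(|&x| Complex::new(x, 0.0)).collect();
    fft.process(&mut a);
    fft.process(&mut b);

    // Pointwise product of the spectra, then the (unnormalized) inverse FFT.
    let mut prod: Vec<Complex<f32>> = a.iter().zip(b.iter()).map(|(x, y)| (*x) * (*y)).collect();
    ifft.process(&mut prod);

    // RustFFT's inverse transform is unnormalized, so divide by n.
    prod.iter().map(|c| c.re / n as f32).collect()
}
```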

janroden avatar Oct 26 '24 08:10 janroden

I just saw that you already opened this issue (that has FFT as a to-do item) over at the CubeCL repo, so linking it here for reference.

janroden avatar Oct 26 '24 08:10 janroden