FFT Module (like torch.fft)
Hello burn team! Thank you for all the hard work!
Feature description
It would be incredibly valuable to have a Fast Fourier Transform (FFT) module, similar to torch.fft in PyTorch.
Feature motivation
Having it as part of the framework would add a foundational building block for working with differentiable digital signal processing (e.g. DDSP), spectral losses, and audio-related tasks.
If this feature is added, this crate might be helpful: https://github.com/ejmahler/RustFFT
It has very fast, SIMD-accelerated FFT processing.
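For reference, RustFFT's planner-based API runs in place on a buffer of `Complex` values; a minimal sketch based on the crate's README:

```rust
use rustfft::{num_complex::Complex, FftPlanner};

fn main() {
    // Plan a forward FFT of length 1024, then run it in place on the buffer.
    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(1024);

    let mut buffer = vec![Complex { re: 0.0f32, im: 0.0f32 }; 1024];
    fft.process(&mut buffer);
}
```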
Since Burn uses different backends, we should take advantage of this. Doing this purely on the CPU will incur data transfer overhead for GPU backends. I believe @Gadersd has implemented something already using Burn ops: https://github.com/Gadersd/whisper-burn
Perhaps we can implement a default for the CPU and let GPU backends override it with a more specialized implementation.
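To illustrate the idea (not Burn's actual ops traits; all names below are made up), a default trait method gives every backend a portable host-side path, and a GPU backend can override it with a device-side kernel:

```rust
// Sketch of the "portable default, specialized override" pattern in plain Rust.
// `FftOps`, `NdArrayLike`, `WgpuLike`, and `host_fft` are illustrative names only.
trait FftOps {
    /// Portable default: compute on the host (e.g. via RustFFT), accepting a
    /// device-to-host copy on GPU backends.
    fn fft_1d(&self, signal: &mut [(f32, f32)]) {
        host_fft(signal);
    }
}

struct NdArrayLike;
struct WgpuLike;

impl FftOps for NdArrayLike {} // CPU backend: the default is already fine

impl FftOps for WgpuLike {
    /// GPU backend: override with a device-side kernel so data never leaves the GPU.
    fn fft_1d(&self, signal: &mut [(f32, f32)]) {
        // ...launch a compute-shader FFT here; the host path stands in for the sketch.
        host_fft(signal);
    }
}

fn host_fft(_signal: &mut [(f32, f32)]) {
    // stand-in for a CPU FFT such as RustFFT's `Fft::process`
}

fn main() {
    let mut signal = vec![(1.0_f32, 0.0); 8];
    NdArrayLike.fft_1d(&mut signal);
    WgpuLike.fft_1d(&mut signal);
}
```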
Feel free to use my FFT implementation, but be aware that it isn’t as general as the one torch uses. There are some padding options and checks that should be added for generality.
I've started work on a PR for this feature, based on an NdArray / Wgpu op I wrote as a custom backend a few months back. There are a few design decisions to get right. We don't support complex numbers at the moment, so the way I see it there are three ways to handle complex inputs and outputs for each op:
- Add a complex type as an `Element` - probably the most idiomatic, but I'm not sure if this is practical
- Require separate real and imaginary tensors
- Require a complex dimension of length 2 as the last dimension (a small layout sketch follows below)
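To make the third option concrete, here is a tiny, self-contained illustration of the layout convention it implies (illustrative only, not a proposed API):

```rust
// Illustrative sketch: option 3 packs a length-N complex signal into a real tensor
// of shape [N, 2], with the last axis holding (real, imaginary). No new Element
// type is needed, at the cost of an implicit layout convention.
fn pack_complex(re: &[f32], im: &[f32]) -> Vec<[f32; 2]> {
    re.iter().zip(im).map(|(&r, &i)| [r, i]).collect()
}

fn main() {
    let packed = pack_complex(&[1.0, 0.0, -1.0, 0.0], &[0.0; 4]);
    assert_eq!(packed.len(), 4); // conceptually a [4, 2] tensor
    println!("{packed:?}");
}
```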
Secondly, I'm guessing nobody will be too upset if we don't have a backward pass for FFT? The derivative of the Fourier transform of a function is itself complex, and propagating imaginary gradients around the place seems like a bad idea. I'm assuming folks will want FFT to preprocess / extract features rather than as an op that needs to be in the graph.
Thoughts?
Having it for the forward pass will cover the majority of use cases. As you mentioned, many simply need to extract features, and it is better to do this using the backend device instead of the CPU.
The only thing I could suggest is to file issues for the remaining work so others can pick it up or refer to the deficiencies.
I have a branch for an FFT op just about ready for a pull request :slightly_smiling_face: It's implemented and working for NdArray and WGPU with a WGSL shader. I've done the 1D FFT and 1D inverse FFT, but will add an issue for supporting 2D soon. I haven't added Candle or autodiff. There are a couple of outstanding tasks before it's ready to be merged to main, though:
- Rewrite the WGSL kernel into the new JIT GPU representation (@louisfd you still up for this?)
- I'm not sure how to register this as an op in the fusion backend. The input and output shapes are the same for all the FFT ops, which I suppose makes the op description simpler - is there an existing op that I can base the implementation on?
- I've not added support for the Torch backend, but we can do this down the line too; `libtorch` does support FFT.
If somebody can make an FFT branch in this repo, then I can merge my fork in.
Hi Tom, I do believe the Fourier transform would be a nice addition to Burn 😄 I made a branch so we can work on this together: https://github.com/tracel-ai/burn/tree/feat/ops/fft
I looked at your WGSL and it won't be very hard to translate to JIT. But before that, can you make sure it's thoroughly tested? The translation process is error-prone, so I need tests to make sure I don't screw up.
Sounds good. I have a bunch of tests, but I'll check some test cases from other libraries tomorrow to make sure I cover all the edge cases; there are a lot of different parts to the algorithm.
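One cheap way to get library-independent test cases is to compare the kernels against a naive O(N²) reference DFT; a self-contained sketch of such an oracle (not part of the branch):

```rust
use std::f64::consts::PI;

/// Naive O(N^2) reference DFT: X[k] = sum_j x[j] * exp(-2*pi*i*j*k/N).
/// Input and output are (re, im) pairs; only useful as a test oracle.
fn dft_reference(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let n = input.len();
    (0..n)
        .map(|k| {
            input.iter().enumerate().fold((0.0, 0.0), |(re, im), (j, &(xr, xi))| {
                let angle = -2.0 * PI * (j * k) as f64 / n as f64;
                let (c, s) = (angle.cos(), angle.sin());
                // accumulate (xr + i*xi) * (c + i*s)
                (re + xr * c - xi * s, im + xr * s + xi * c)
            })
        })
        .collect()
}

fn main() {
    // A unit impulse should transform to an all-ones (flat) spectrum.
    let impulse: Vec<(f64, f64)> =
        (0..8).map(|j| if j == 0 { (1.0, 0.0) } else { (0.0, 0.0) }).collect();
    for (k, &(re, im)) in dft_reference(&impulse).iter().enumerate() {
        assert!((re - 1.0).abs() < 1e-9 && im.abs() < 1e-9, "bin {k} is off: ({re}, {im})");
    }
    println!("impulse test passed");
}
```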
Hello folks! This feature would indeed be pretty helpful -- I was about to experiment with using Burn to train some audio models and found myself reaching for an FFT module. I'm happy to help with such a feature. I saw there is a feature branch, but I wasn't sure what its current state is. Are there any tasks that can be peeled off that I could contribute to?
One additional note -- it would be wonderful if autodiff functionality were included. Reading the feature branch, it seems this might be pushed off to the future. Differentiating through the FFT is indeed used in many ML audio applications (spectral losses, DDSP, etc. mentioned in the original issue text). As a user, I would find an implementation without autodiff support very surprising. Again, happy to help make this a reality.
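One hedged note in support of this: the DFT is a linear map, so in the unnormalized complex-to-complex convention its backward pass is just the adjoint transform applied to the incoming cotangent,

$$X_k = \sum_{j=0}^{N-1} x_j \, e^{-2\pi i jk/N} = (Fx)_k, \qquad \bar{x} = F^{\mathsf{H}} \bar{X} = N \cdot \operatorname{IFFT}(\bar{X}),$$

so the backward of `fft` can be expressed with the existing `ifft` kernel (up to the factor of $N$). Real-input variants need extra care for the Hermitian-symmetry packing, and complex gradients follow whatever Wirtinger convention the autodiff backend adopts, but no fundamentally new machinery should be required.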
I think the right approach would be to provide FFT kernels in CubeCL; we can then bundle them with Burn.
> I think the right approach would be to provide FFT kernels in CubeCL; we can then bundle them with Burn.
It would be great to have FFT in CubeCL as well!
As far as I can tell, the larger Rust ecosystem is currently lacking a portable FFT that can also run on the GPU.
For the CPU, there's RustFFT and ndrustfft, a wrapper for RustFFT for complex/real n-dimensional FFTs (using ndarray).
For the GPU, there are crates like vkfft-rs, tch-rs (a Rust PyTorch wrapper), and arrayfire-rust (a wrapper for ArrayFire).
But it would be great to also have a portable WebGPU FFT, so we don't have to install huge dependencies and can ship stand-alone binaries.
FFT on GPU, especially for large data, is normally much faster than on the CPU (depending on the GPU and CPU of course).
For instance, 3D image data in medical imaging and microscopy can be gigabytes in size, and one would normally want to run the FFT on the GPU in that case.
Having an efficient, portable (n-dim) FFT (e.g. based on WebGPU) would also enable significantly faster convolutions on large data.
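For context, that speedup comes from the convolution theorem: with suitable zero-padding, convolution becomes an elementwise product in the frequency domain,

$$a * b = \operatorname{IFFT}\big(\operatorname{FFT}(a) \odot \operatorname{FFT}(b)\big),$$

dropping the cost from roughly O(N·K) for direct convolution with a size-K kernel to O(N log N), which is a large win once kernels and volumes get big.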
I just saw that you already opened this issue (that has FFT as a to-do item) over at the CubeCL repo, so linking it here for reference.